Example #1
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Load a language model either by supplying

        * the name of a remote model on s3 ("albert-base" ...)
        * or a local path of a model trained via transformers ("some_dir/huggingface_model")
        * or a local path of a model trained via FARM ("some_dir/farm_model")

        :param pretrained_model_name_or_path: name or path of a model
        :param language: (Optional) Name of language the model was trained for (e.g. "german").
                         If not supplied, FARM will try to infer it from the model name.
        :return: Language Model

        """
        albert = cls()
        if "farm_lm_name" in kwargs:
            albert.name = kwargs["farm_lm_name"]
        else:
            albert.name = pretrained_model_name_or_path
        # We need to differentiate between loading model using FARM format and Pytorch-Transformers format
        farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
        if os.path.exists(farm_lm_config):
            # FARM style
            config = AlbertConfig.from_pretrained(farm_lm_config)
            farm_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
            albert.model = AlbertModel.from_pretrained(farm_lm_model, config=config, **kwargs)
            albert.language = albert.model.config.language
        else:
            # Huggingface transformer Style
            albert.model = AlbertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
            albert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
        return albert
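The branch above distinguishes FARM checkpoints from plain transformers checkpoints purely by the presence of language_model_config.json; a minimal standalone sketch of that rule (the helper name is ours, not part of the snippet):

import os
from pathlib import Path

def looks_like_farm_checkpoint(model_dir: str) -> bool:
    # FARM checkpoints ship a language_model_config.json next to language_model.bin;
    # plain Hugging Face checkpoints do not, so they go straight to from_pretrained.
    return os.path.exists(Path(model_dir) / "language_model_config.json")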
Example #2
    def __init__(self, config):
        super(ModelA, self).__init__(config)
        self.kbert = False

        if self.kbert:
            self.albert = KBERT(config)
        else:
            self.albert = AlbertModel(config)

        self.att_merge = AttentionMerge(config.hidden_size, 1024, 0.1)

        self.scorer = nn.Sequential(nn.Dropout(0.1),
                                    nn.Linear(config.hidden_size, 1))

        self.init_weights()
Example #3
    def __init__(self, config):
        super(AlbertCrfForNer, self).__init__(config)
        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.init_weights()
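For orientation, the CRF head in AlbertCrfForNer is usually trained and decoded as in the standalone sketch below, which assumes the pytorch-crf package (torchcrf.CRF); the random emissions stand in for classifier(dropout(sequence_output)) from the example.

import torch
from torchcrf import CRF  # assumption: the CRF used above is pytorch-crf's CRF

num_tags, batch, seq_len = 5, 2, 8
crf = CRF(num_tags=num_tags, batch_first=True)

emissions = torch.randn(batch, seq_len, num_tags)      # stand-in for the classifier output
tags = torch.zeros(batch, seq_len, dtype=torch.long)   # dummy gold tag ids
mask = torch.ones(batch, seq_len, dtype=torch.uint8)   # 1 = real token, 0 = padding

loss = -crf(emissions, tags, mask=mask)                # negative log-likelihood as training loss
best_paths = crf.decode(emissions, mask=mask)          # Viterbi-decoded tag sequences at inference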
Example #4
    def __init__(self, config, num_labels=3, dropout=None):
        super(AlbertForABSA, self).__init__(config)
        self.num_labels = num_labels
        self.albert = AlbertModel(config)
        self.dropout = torch.nn.Dropout(dropout)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.init_weights()
Example #5
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()
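Heads like the ones in Examples #3 to #5 sit on top of an AlbertModel backbone; the self-contained sketch below shows the usual wiring with a deliberately tiny, made-up AlbertConfig (none of these projects uses these sizes):

import torch
from torch import nn
from transformers import AlbertConfig, AlbertModel

config = AlbertConfig(hidden_size=128, num_attention_heads=4,
                      intermediate_size=256, num_hidden_layers=2)  # tiny, illustration only
albert = AlbertModel(config)
dropout = nn.Dropout(config.hidden_dropout_prob)
classifier = nn.Linear(config.hidden_size, 1)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
outputs = albert(input_ids=input_ids)
cls_hidden = outputs[0][:, 0]              # hidden state of the first ([CLS]) token
logits = classifier(dropout(cls_hidden))   # shape (batch, 1): one score per sequence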
Example #6
    def __init__(self, config):
        super(AlbertSoftmaxForNer, self).__init__(config)
        self.num_labels = config.num_labels
        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_type = config.loss_type
        self.init_weights()
Example #7
    def __init__(self, config, weight=None):
        super(AlbertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        self.weight = weight

        self.init_weights()
Example #8
    def init_data(self, use_cuda: bool) -> None:
        self.test_device = torch.device('cuda:0') if use_cuda else \
            torch.device('cpu:0')
        if not use_cuda:
            torch.set_num_threads(4)
            turbo_transformers.set_num_threads(4)

        torch.set_grad_enabled(False)
        self.cfg = AlbertConfig(hidden_size=768,
                                num_attention_heads=12,
                                intermediate_size=3072)
        self.torch_model = AlbertModel(self.cfg)

        if torch.cuda.is_available():
            self.torch_model.to(self.test_device)
        self.torch_model.eval()
        self.hidden_size = self.cfg.hidden_size

        self.turbo_model = turbo_transformers.AlbertModel.from_torch(
            self.torch_model)
Example #9
    def __init__(self, config):
        super(AlbertForCloth, self).__init__(config)
        
        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)

        self.init_weights()
        self.tie_weights()

        self.loss = nn.CrossEntropyLoss(reduction='none')
        self.vocab_size = self.albert.embeddings.word_embeddings.weight.size(0)
Example #10
    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        super(JointAlbert, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.albert = AlbertModel(config=config)  # Load pretrained ALBERT

        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)

        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)
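JointAlbert is typically optimized with a weighted sum of an intent cross-entropy loss and a slot loss (CRF negative log-likelihood, or token-level cross-entropy when use_crf is off). A sketch under assumed shapes, with slot_loss_coef as an assumed hyperparameter name:

import torch
import torch.nn.functional as F

batch, seq_len, num_intents, num_slots = 2, 8, 7, 12
intent_logits = torch.randn(batch, num_intents)        # from intent_classifier(pooled_output)
slot_logits = torch.randn(batch, seq_len, num_slots)   # from slot_classifier(sequence_output)
intent_labels = torch.randint(0, num_intents, (batch,))
slot_labels = torch.randint(0, num_slots, (batch, seq_len))

intent_loss = F.cross_entropy(intent_logits, intent_labels)
slot_loss = F.cross_entropy(slot_logits.view(-1, num_slots), slot_labels.view(-1))
slot_loss_coef = 1.0                                    # assumed weighting hyperparameter
total_loss = intent_loss + slot_loss_coef * slot_loss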
Example #11
    def __init__(self, config):
        super(AlbertSpanForNer, self).__init__(config)
        self.soft_label = config.soft_label
        self.num_labels = config.num_labels
        self.loss_type = config.loss_type
        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        if self.soft_label:
            self.end_fc = PoolerEndLogits(config.hidden_size + self.num_labels, self.num_labels)
        else:
            self.end_fc = PoolerEndLogits(config.hidden_size + 1, self.num_labels)
        self.init_weights()
Example #12
        def init_data(self, use_cuda: bool) -> None:
            self.test_device = torch.device('cuda:0') if use_cuda else \
                torch.device('cpu:0')
            if not use_cuda:
                torch.set_num_threads(4)
                turbo_transformers.set_num_threads(4)

            torch.set_grad_enabled(False)
            self.cfg = AlbertConfig()

            self.torch_model = AlbertModel(self.cfg)
            if torch.cuda.is_available():
                self.torch_model.to(self.test_device)
            self.torch_model.eval()
            self.hidden_size = self.cfg.hidden_size
            self.input_tensor = torch.randint(low=0,
                                              high=self.cfg.vocab_size - 1,
                                              size=(batch_size, seq_length),
                                              device=self.test_device)

            self.turbo_model = turbo_transformers.AlbertModel.from_torch(
                self.torch_model)
Example #13
    def __init__(self, config, non_interaction_layers=None):
        super(DilAlbert, self).__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        
        ## Non Interaction Size
        if non_interaction_layers is not None:
            self.non_interaction_layers = non_interaction_layers
            print(f"Dilalbert: non_interaction_layers: Use of class parameter during initialization. Value {self.non_interaction_layers}")
        elif hasattr(config, 'non_interaction_layers'):
            self.non_interaction_layers = config.non_interaction_layers
            print(f"Dilalbert: non_interaction_layers: Use of config file variable. Value {self.non_interaction_layers}")
        else:
            self.non_interaction_layers = DEFAULT_NON_INTERACTION_LAYERS
            print(f"Dilalbert: non_interaction_layers: Use of default value. Value {self.non_interaction_layers}")
        
        if EVAL_TIME:
            self.count = 0
            self.time_perf = {
                'qst tokens count': 0,
                'ctxt tokens count': 0,
                'split qst ctxt': 0,
                'qst process bert input': 0,
                'ctxt process bert input': 0,
                'qst embed': 0,
                'ctxt embed': 0,
                'qst part A': 0,
                'ctxt part A': 0,
                'process bert input': 0,
                'part B': 0,
                'part C': 0
            }

        self.init_weights()
Example #14
    def __init__(self, cfg):
        super(DSB_ALBERTModel, self).__init__()
        self.cfg = cfg
        cate_col_size = len(cfg.cate_cols)
        cont_col_size = len(cfg.cont_cols)
        self.cate_emb = nn.Embedding(cfg.total_cate_size,
                                     cfg.emb_size,
                                     padding_idx=0)

        def get_cont_emb():
            return nn.Sequential(nn.Linear(cont_col_size, cfg.hidden_size),
                                 nn.LayerNorm(cfg.hidden_size), nn.ReLU(),
                                 nn.Linear(cfg.hidden_size, cfg.hidden_size))

        self.cont_emb = get_cont_emb()
        self.config = AlbertConfig(
            3,  # not used
            embedding_size=cfg.emb_size * cate_col_size + cfg.hidden_size,
            hidden_size=cfg.emb_size * cate_col_size + cfg.hidden_size,
            num_hidden_layers=cfg.nlayers,
            #num_hidden_groups=1,
            num_attention_heads=cfg.nheads,
            intermediate_size=cfg.hidden_size,
            hidden_dropout_prob=cfg.dropout,
            attention_probs_dropout_prob=cfg.dropout,
            max_position_embeddings=cfg.seq_len,
            type_vocab_size=1,
            #initializer_range=0.02,
            #layer_norm_eps=1e-12,
        )

        self.encoder = AlbertModel(self.config)

        def get_reg():
            return nn.Sequential(
                nn.Linear(cfg.emb_size * cate_col_size + cfg.hidden_size,
                          cfg.hidden_size),
                nn.LayerNorm(cfg.hidden_size),
                nn.Dropout(cfg.dropout),
                nn.ReLU(),
                nn.Linear(cfg.hidden_size, cfg.hidden_size),
                nn.LayerNorm(cfg.hidden_size),
                nn.Dropout(cfg.dropout),
                nn.ReLU(),
                nn.Linear(cfg.hidden_size, cfg.target_size),
            )

        self.reg_layer = get_reg()
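Example #14 repurposes ALBERT as a generic sequence encoder over concatenated categorical/continuous embeddings; presumably its forward pass (not shown) feeds those features in via inputs_embeds rather than input_ids. A minimal sketch of that mechanism with made-up sizes:

import torch
from transformers import AlbertConfig, AlbertModel

hidden = 64
config = AlbertConfig(vocab_size=3,  # irrelevant when feeding inputs_embeds
                      embedding_size=hidden, hidden_size=hidden,
                      num_hidden_layers=2, num_attention_heads=4,
                      intermediate_size=128, max_position_embeddings=32,
                      type_vocab_size=1)
encoder = AlbertModel(config)

batch, seq_len = 2, 16
features = torch.randn(batch, seq_len, hidden)   # e.g. concatenated cate/cont embeddings
mask = torch.ones(batch, seq_len, dtype=torch.long)
sequence_output = encoder(inputs_embeds=features, attention_mask=mask)[0]  # (batch, seq_len, hidden)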
Example #15
    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        super(JointAlbert, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.albert = AlbertModel(config=config)  # Load pretrained ALBERT
        # self.dropout = nn.Dropout(args.dropout_rate)
        # self.lstm = nn.LSTM(config.hidden_size, config.hidden_size // 2, num_layers=1, batch_first=True, bidirectional=True)
        # self.gru = nn.GRU(config.hidden_size, config.hidden_size // 2, num_layers=1, batch_first=True, bidirectional=True)
        self.intent_classifier = IntentClassifier(config.hidden_size,
                                                  self.num_intent_labels,
                                                  args.dropout_rate)
        self.slot_classifier = SlotClassifier(config.hidden_size,
                                              self.num_slot_labels,
                                              args.dropout_rate)

        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)
Example #16
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)

        # For Masked LM
        # The original huggingface implementation created new output weights via a dense layer.
        # However, the original ALBERT ties the decoder weights to the input embeddings,
        # which is what the assignment to predictions_decoder.weight below does.
        self.predictions_dense = nn.Linear(config.hidden_size,
                                           config.embedding_size)
        self.predictions_activation = ACT2FN[config.hidden_act]
        self.predictions_LayerNorm = nn.LayerNorm(config.embedding_size)
        self.predictions_bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.predictions_decoder = nn.Linear(config.embedding_size,
                                             config.vocab_size)

        self.predictions_decoder.weight = self.albert.embeddings.word_embeddings.weight

        # For sequence order prediction
        self.seq_relationship = AlbertSequenceOrderHead(config)
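For reference, a head assembled from these modules is normally applied as dense -> activation -> LayerNorm -> tied decoder plus bias; since the forward pass is not shown above, the standalone sketch below is an assumption modeled on ALBERT's MLM head:

import torch
from torch import nn
import torch.nn.functional as F

hidden_size, embedding_size, vocab_size = 768, 128, 30000   # illustrative sizes
dense = nn.Linear(hidden_size, embedding_size)
layer_norm = nn.LayerNorm(embedding_size)
decoder = nn.Linear(embedding_size, vocab_size)
bias = nn.Parameter(torch.zeros(vocab_size))

sequence_output = torch.randn(2, 16, hidden_size)           # stand-in for albert(...)[0]
h = layer_norm(F.gelu(dense(sequence_output)))
mlm_logits = decoder(h) + bias                              # (batch, seq_len, vocab_size)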
Example #17
    def __init__(self, config, args):
        super().__init__(config)
        self.args = args

        if args.bert_model == "albert-base-v2":
            bert = AlbertModel.from_pretrained(args.bert_model)
        elif args.bert_model == "emilyalsentzer/Bio_ClinicalBERT":
            bert = AutoModel.from_pretrained(args.bert_model)
        elif args.bert_model == "bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12":
            bert = AutoModel.from_pretrained(args.bert_model)
        elif args.bert_model == "bert-small-scratch":
            config = BertConfig.from_pretrained(
                "google/bert_uncased_L-4_H-512_A-8")
            bert = BertModel(config)
        elif args.bert_model == "bert-base-scratch":
            config = BertConfig.from_pretrained("bert-base-uncased")
            bert = BertModel(config)
        else:
            bert = BertModel.from_pretrained(
                args.bert_model)  # bert-base-uncased, small, tiny

        self.txt_embeddings = bert.embeddings
        self.img_embeddings = ImageBertEmbeddings(args, self.txt_embeddings)

        if args.img_encoder == 'ViT':
            img_size = args.img_size
            patch_sz = 32 if img_size == 512 else 16
            self.img_encoder = Img_patch_embedding(image_size=img_size,
                                                   patch_size=patch_sz,
                                                   dim=2048)
        else:
            self.img_encoder = ImageEncoder_cnn(args)
            for p in self.img_encoder.parameters():
                p.requires_grad = False
            for c in list(self.img_encoder.children())[5:]:
                for p in c.parameters():
                    p.requires_grad = True

        self.encoder = bert.encoder
        self.pooler = bert.pooler
Example #18
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 probe_type: str = None,
                 layer_freeze_regexes: List[str] = None,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._pretrained_model = pretrained_model
        if 'roberta' in pretrained_model:
            self._padding_value = 1  # The index of the RoBERTa padding token
            self._transformer_model = RobertaModel.from_pretrained(
                pretrained_model)
            self._dropout = torch.nn.Dropout(
                self._transformer_model.config.hidden_dropout_prob)
        elif 'xlnet' in pretrained_model:
            self._padding_value = 5  # The index of the XLNet padding token
            self._transformer_model = XLNetModel.from_pretrained(
                pretrained_model)
            self.sequence_summary = SequenceSummary(
                self._transformer_model.config)
        elif 'albert' in pretrained_model:
            self._transformer_model = AlbertModel.from_pretrained(
                pretrained_model)
            self._padding_value = 0  # The index of the BERT padding token
            self._dropout = torch.nn.Dropout(
                self._transformer_model.config.hidden_dropout_prob)
        elif 'bert' in pretrained_model:
            self._transformer_model = BertModel.from_pretrained(
                pretrained_model)
            self._padding_value = 0  # The index of the BERT padding token
            self._dropout = torch.nn.Dropout(
                self._transformer_model.config.hidden_dropout_prob)
        else:
            raise ValueError(f"Unsupported pretrained model: {pretrained_model}")

        if probe_type == 'MLP':
            layer_freeze_regexes = ["embeddings", "encoder"]

        for name, param in self._transformer_model.named_parameters():
            if layer_freeze_regexes and requires_grad:
                grad = not any(
                    [bool(re.search(r, name)) for r in layer_freeze_regexes])
            else:
                grad = requires_grad
            if grad:
                param.requires_grad = True
            else:
                param.requires_grad = False

        transformer_config = self._transformer_model.config
        transformer_config.num_labels = 1
        self._output_dim = self._transformer_model.config.hidden_size

        # unifying all model classification layers
        self._classifier = Linear(self._output_dim, 1)
        self._classifier.weight.data.normal_(mean=0.0, std=0.02)
        self._classifier.bias.data.zero_()

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
        self._debug = 2
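The freezing loop above can be reproduced in isolation as follows; the tiny random model and the regex list (copied from the 'MLP' probe_type branch) are only for illustration:

import re
from transformers import AlbertConfig, AlbertModel

model = AlbertModel(AlbertConfig(hidden_size=128, num_attention_heads=4,
                                 intermediate_size=256, num_hidden_layers=2))
layer_freeze_regexes = ["embeddings", "encoder"]   # same patterns the 'MLP' probe uses above

for name, param in model.named_parameters():
    param.requires_grad = not any(re.search(r, name) for r in layer_freeze_regexes)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{trainable} trainable parameters")          # only the pooler survives these patterns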
Example #19
class ModelA(AlbertPreTrainedModel):
    """
    AlBert-AttentionMerge-Classifier

    1. self.forward(idx, input_ids, attention_mask, token_type_ids, labels)
    2. self.predict(idx, input_ids, attention_mask, token_type_ids)
    """
    def __init__(self, config):
        super(ModelA, self).__init__(config)
        self.kbert = False

        if self.kbert:
            self.albert = KBERT(config)
        else:
            self.albert = AlbertModel(config)

        self.att_merge = AttentionMerge(config.hidden_size, 1024, 0.1)

        self.scorer = nn.Sequential(nn.Dropout(0.1),
                                    nn.Linear(config.hidden_size, 1))

        self.init_weights()

    def score(self, h1, h2, h3, h4, h5):
        """
        h1, ..., h5: [B, H] => logits: [B, 5]
        """
        logits1 = self.scorer(h1)
        logits2 = self.scorer(h2)
        logits3 = self.scorer(h3)
        logits4 = self.scorer(h4)
        logits5 = self.scorer(h5)
        logits = torch.cat((logits1, logits2, logits3, logits4, logits5),
                           dim=1)
        return logits

    def forward(self, idx, input_ids, attention_mask, token_type_ids, labels):
        """
        input_ids: [B, 5, L]
        labels: [B, ]
        """
        # logits: [B, 5]
        logits = self._forward(idx, input_ids, attention_mask, token_type_ids)
        loss = F.cross_entropy(logits, labels)

        with torch.no_grad():
            logits = F.softmax(logits, dim=1)
            predicts = torch.argmax(logits, dim=1)
            right_num = torch.sum(predicts == labels)

        return loss, right_num, self._to_tensor(idx.size(0), idx.device)

    def _forward(self, idx, input_ids, attention_mask, token_type_ids):
        # [B, 5, L] => [B*5, L]
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))

        outputs = self.albert(input_ids=flat_input_ids,
                              attention_mask=flat_attention_mask,
                              token_type_ids=flat_token_type_ids)

        if self.kbert:
            flat_attention_mask = self.albert.get_attention_mask()

        # outputs[0]: [B*5, L, H] => [B*5, H]
        h12 = self.att_merge(outputs[0], flat_attention_mask)

        # [B*5, H] => [B*5, 1] => [B, 5]
        logits = self.scorer(h12).view(-1, 5)

        return logits

    def predict(self, idx, input_ids, attention_mask, token_type_ids):
        """
        return: [B, 5]
        """
        return self._forward(idx, input_ids, attention_mask, token_type_ids)

    def _to_tensor(self, it, device):
        return torch.tensor(it, device=device, dtype=torch.float)
Example #20
class DilAlbert(AlbertPreTrainedModel):
    def __init__(self, config, non_interaction_layers=None):
        super(DilAlbert, self).__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        
        ## Non Interaction Size
        if non_interaction_layers is not None:
            self.non_interaction_layers = non_interaction_layers
            print(f"Dilalbert: non_interaction_layers: Use of class parameter during initialization. Value {self.non_interaction_layers}")
        elif hasattr(config, 'non_interaction_layers'):
            self.non_interaction_layers = config.non_interaction_layers
            print(f"Dilalbert: non_interaction_layers: Use of config file variable. Value {self.non_interaction_layers}")
        else:
            self.non_interaction_layers = DEFAULT_NON_INTERACTION_LAYERS
            print(f"Dilalbert: non_interaction_layers: Use of default value. Value {self.non_interaction_layers}")
        
        if EVAL_TIME:
            self.count = 0
            self.time_perf = {
                'qst tokens count': 0,
                'ctxt tokens count': 0,
                'split qst ctxt': 0,
                'qst process bert input': 0,
                'ctxt process bert input': 0,
                'qst embed': 0,
                'ctxt embed': 0,
                'qst part A': 0,
                'ctxt part A': 0,
                'process bert input': 0,
                'part B': 0,
                'part C': 0
            }

        self.init_weights()
        
        #if SAME_LAYER_GROUP_A_AND_B:
            #copy.deepcopy(model) copy layer group

    
    def split_question_context(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None
    ):
        batch_size = len(input_ids)
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        question_inputs = {'input_ids': [], 'attention_mask': [], 'token_type_ids': [], 'position_ids': [], 'inputs_embeds': []}
        context_inputs = {'input_ids': [], 'attention_mask': [], 'token_type_ids': [], 'position_ids': [], 'inputs_embeds': []}
        
        seq_len = len(input_ids[0])
        split_idxs = []
        for i in range(batch_size):  ## for every example in batch
            idx = (input_ids[i] == 3).nonzero()[0][0] ## Here replace 3 with tokenizer.sep_token_id
            split_idxs.append(idx)
        max_qst_len = max(split_idxs)+1
        max_ctxt_len = len(input_ids[0]) - min(split_idxs) -1
        
        paddings = []
        for i in range(batch_size):  ## for every example in batch
            split_idx = split_idxs[i]
            qst_padding = max_qst_len - split_idx - 1
            ctxt_padding = max_ctxt_len-(seq_len-split_idx-1)
            paddings.append(qst_padding)
            
            if input_ids is not None:
                question_inputs['input_ids'].append(input_ids[i][:split_idx+1].tolist() + [0]*qst_padding)
                context_inputs['input_ids'].append(input_ids[i][split_idx+1:].tolist() + [0]*ctxt_padding)
            if attention_mask is not None:
                question_inputs['attention_mask'].append(attention_mask[i][:split_idx+1].tolist() + [0]*qst_padding)
                context_inputs['attention_mask'].append(attention_mask[i][split_idx+1:].tolist() + [0]*ctxt_padding)
            if token_type_ids is not None:
                question_inputs['token_type_ids'].append(token_type_ids[i][:split_idx+1].tolist() + [0]*qst_padding)
                context_inputs['token_type_ids'].append(token_type_ids[i][split_idx+1:].tolist() + [1]*ctxt_padding)
            ## these embeddings are disabled and the BERT encoder uses the default ones
            if False:
                if position_ids is not None:
                    question_inputs['position_ids'].append(position_ids[i][:split_idx+1].tolist())
                    context_inputs['position_ids'].append(position_ids[i][split_idx+1:].tolist())
                if inputs_embeds is not None:
                    question_inputs['inputs_embeds'].append(inputs_embeds[i][:split_idx+1].tolist())
                    context_inputs['inputs_embeds'].append(inputs_embeds[i][split_idx+1:].tolist())

        question_inputs['input_ids'] = torch.tensor(question_inputs['input_ids'], device=device)#, requires_grad=True)
        question_inputs['token_type_ids'] = torch.tensor(question_inputs['token_type_ids'], device=device)
        question_inputs['position_ids'] = None # torch.tensor(question_inputs['position_ids'], device=device)
        question_inputs['attention_mask'] = torch.tensor(question_inputs['attention_mask'], device=device)
        question_inputs['inputs_embeds'] = None
        context_inputs['input_ids'] = torch.tensor(context_inputs['input_ids'], device=device)#, requires_grad=True)
        context_inputs['token_type_ids'] = torch.tensor(context_inputs['token_type_ids'], device=device)
        context_inputs['position_ids'] = None # torch.tensor(context_inputs['position_ids'], device=device)
        context_inputs['attention_mask'] = torch.tensor(context_inputs['attention_mask'], device=device)
        context_inputs['inputs_embeds'] = None

        return question_inputs, context_inputs, split_idxs, paddings
    
    
    def process_albert_input(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None
    ):
        
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.albert.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.albert.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers
        
        return {
            'input_ids': input_ids,
            'attention_mask': extended_attention_mask,
            'position_ids': position_ids,
            'token_type_ids': token_type_ids,
            'head_mask': head_mask,
            'inputs_embeds': inputs_embeds
        }

    def embed(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, **kwargs):
        """Return BERT embeddings."""
        return self.albert.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )


    def forward_encoder(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        first_part=True
    ):
        
        if first_part:
            hidden_states = self.albert.encoder.embedding_hidden_mapping_in(hidden_states)
            num_layers = self.non_interaction_layers
        else:
            num_layers = self.albert.encoder.config.num_hidden_layers - self.non_interaction_layers
        
        
        all_attentions = ()

        if self.albert.encoder.output_hidden_states:
            all_hidden_states = (hidden_states,)
      
        for i in range(num_layers):                        
            
            # Number of layers in a hidden group
            layers_per_group = int(self.albert.encoder.config.num_hidden_layers / self.albert.encoder.config.num_hidden_groups)

            # Index of the hidden group
            group_idx = int(i / (self.albert.encoder.config.num_hidden_layers / self.albert.encoder.config.num_hidden_groups))
            
            layer_group_output = self.forward_albert_layer_group(group_idx, hidden_states, attention_mask=attention_mask, head_mask=head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],first_part=first_part)
            
            hidden_states = layer_group_output[0]

            if self.albert.encoder.output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if self.albert.encoder.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.albert.encoder.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.albert.encoder.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)        
    
    def forward_albert_layer_group(self,group_idx, hidden_states, attention_mask=None, head_mask=None,first_part=True):
        
        
        # Not modified in the end, because it's not supposed to be here...
        
        layer_hidden_states = ()
        layer_attentions = ()
        

        for layer_index, albert_layer in enumerate(self.albert.encoder.albert_layer_groups[group_idx].albert_layers):               
            
            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index])
            hidden_states = layer_output[0]

            if self.albert.encoder.albert_layer_groups[group_idx].output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if self.albert.encoder.albert_layer_groups[group_idx].output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.albert.encoder.albert_layer_groups[group_idx].output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if self.albert.encoder.albert_layer_groups[group_idx].output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)
        
    
    def forward_albert(
        self,
        embeddings,
        attention_mask=None,
        head_mask=None,
        first_part=True,
        **kwargs
    ):
        encoder_outputs = self.forward_encoder(
            embeddings,
            attention_mask=attention_mask,
            head_mask=head_mask,
            first_part=first_part
        )
        
        sequence_output = encoder_outputs[0]

        pooled_output = self.albert.pooler_activation(self.albert.pooler(sequence_output[:, 0]))

        outputs = (sequence_output, pooled_output) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs
        
    
    
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None
    ):
        torch.set_printoptions(profile="full")
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        t0, t1, t2, t3, t4, t5, t6, t7 = [None]*8
        
        
        if EVAL_TIME: t0 = perf_counter()
        question_inputs, context_inputs, split_idxs, paddings = self.split_question_context(input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, start_positions, end_positions)
        if EVAL_TIME: self.time_perf['split qst ctxt'] += (perf_counter() - t0)

        
        if EVAL_TIME: t1 = perf_counter()
        question_inputs = self.process_albert_input(**question_inputs)
        if EVAL_TIME: t2 = perf_counter()
        embeddings = self.embed(**question_inputs)
        
        
        if IGNORE_CLS_PART_A:
            cls_copy = embeddings[:,0]

        if EVAL_TIME: t3 = perf_counter()

        qst_out = self.forward_albert(
            **question_inputs,
            embeddings=embeddings,
            first_part=True
        )
        
        if EVAL_TIME: t4 = perf_counter()
        context_inputs = self.process_albert_input(**context_inputs)
        if EVAL_TIME: t5 = perf_counter()
        embeddings = self.embed(**context_inputs)
        if EVAL_TIME: t6 = perf_counter()
        ctxt_out = self.forward_albert(
            **context_inputs,
            embeddings=embeddings,
            first_part=True
        )
        
        if EVAL_TIME:
            t7 = perf_counter()
            self.time_perf['qst process bert input'] += (t2-t1)
            self.time_perf['qst embed'] += (t3-t2)
            self.time_perf['qst part A'] += (t4-t3)
            self.time_perf['ctxt process bert input'] += (t5-t4)
            self.time_perf['ctxt embed'] += (t6-t5)
            self.time_perf['ctxt part A'] += (t7-t6)
            t1 = perf_counter()
        
        
        outs = torch.cat((qst_out[0], ctxt_out[0]), 1)
        clone = outs.clone()
        for i, (idx, pad) in enumerate(zip(split_idxs, paddings)):
            if pad == 0: continue
            outs[i, idx+1:-pad] = clone[i, idx+1+pad:]
        hidden_states = outs = outs[:,:MAX_SEQ_LENGTH]
        #print(MAX_SEQ_LENGTH)
        #print(hidden_states.shape)
        #attention_mask = attention_mask[:, :383]
        
        
        if IGNORE_CLS_PART_A:
            hidden_states[:,0] = cls_copy
        
        albert_input = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'position_ids': position_ids,
            'head_mask': head_mask,
            'inputs_embeds':inputs_embeds
        }
        
        albert_input = self.process_albert_input(**albert_input)
        #bert_input['attention_mask'] = torch.cat((question_inputs['attention_mask'], context_inputs['attention_mask']), 3)[:, :, :, :MAX_SEQ_LENGTH]
        
        if UPDATE_CLS:
            hidden_states[:,0] = hidden_states[:,1:].mean(dim=1)
        if UPDATE_UMBEDDINGS:
            input_shape = input_ids.size()
            seq_length = input_shape[1]
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if position_ids is None:
                position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
                position_ids = position_ids.unsqueeze(0).expand(input_shape)
            hidden_states = (hidden_states +
                            self.albert.embeddings.position_embeddings(position_ids) +
                            self.albert.embeddings.token_type_embeddings(token_type_ids))
            hidden_states = self.albert.embeddings.LayerNorm(hidden_states)
            hidden_states = self.albert.embeddings.dropout(hidden_states)

        if EVAL_TIME: t2 = perf_counter()
        #print(hidden_states.shape)
        #input()
        outputs = self.forward_albert(
            **albert_input,
            embeddings=hidden_states,
            first_part=False
        )
        hidden_states = outputs[0]
        
        if UPDATE_FINAL_CLS:
            hidden_states[:,0] = hidden_states[:,1:].mean(dim=1)
        
        
        if EVAL_TIME: t3 = perf_counter()
        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        
        if EVAL_TIME:
            t4 = perf_counter()
            self.time_perf['process bert input'] += (t2-t1)
            self.time_perf['part B'] += (t3-t2)
            self.time_perf['part C'] += (t4-t3)

            self.count += 1
            ## print every 50 batches
            if self.count%50 == 0 or self.count == 1036:  # 12430
                for k, v in self.time_perf.items(): print(f"{k}: {v}")
        
        if not TRAINING and False:
            start_logits = start_logits.tolist()
            end_logits = end_logits.tolist()
            fin = MAX_SEQ_LENGTH - max(paddings)
            for i in range(len(start_logits)):
                start_logits[i] = start_logits[i][paddings[i]:paddings[i]+fin]
                end_logits[i] = end_logits[i][paddings[i]:paddings[i]+fin]
            start_logits = torch.tensor(start_logits) #, device=device)
            end_logits = torch.tensor(end_logits) #, device=device)
        
        
        outputs = (start_logits, end_logits,) + outputs[2:] #+ (paddings,)
        if start_positions is not None and end_positions is not None:
            #paddings = torch.tensor(paddings, device=device)
            #start_positions += paddings
            #end_positions += paddings
            
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs


        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)


    #non interaction layers
    def process_A(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None
    ):
        torch.set_printoptions(profile="full")
        albert_input = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'position_ids': position_ids,
            'head_mask': head_mask,
            'inputs_embeds':inputs_embeds
        }
        
        albert_input = self.process_albert_input(**albert_input)
        embeddings = self.embed(**albert_input)
        
        albert_out = self.forward_albert(
            **albert_input,
            embeddings=embeddings,
            first_part=True
        )
        
        return albert_out[0]
    
    # interaction layers
    def process_B(
        self,
        qst_embeddings,
        ctxt_embeddings,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None
    ):
        
        outs = torch.cat((qst_embeddings, ctxt_embeddings), 1)
        hidden_states = outs = outs[:,:MAX_SEQ_LENGTH]

        albert_input = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'position_ids': position_ids,
            'head_mask': head_mask,
            'inputs_embeds':inputs_embeds
        }
        
        albert_input = self.process_albert_input(**albert_input)
        
        outputs = self.forward_albert(
            **albert_input,
            embeddings=hidden_states,
            first_part=False
        )
        hidden_states = outputs[0]
        
        
        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
              
        
        outputs = (start_logits, end_logits,) + outputs[2:] #+ (paddings,)
        if start_positions is not None and end_positions is not None:
            #paddings = torch.tensor(paddings, device=device)
            #start_positions += paddings
            #end_positions += paddings
            
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
        
Example #21
    class TestAlbertModel(unittest.TestCase):
        def init_data(self, use_cuda: bool) -> None:
            self.test_device = torch.device('cuda:0') if use_cuda else \
                torch.device('cpu:0')
            if not use_cuda:
                torch.set_num_threads(4)
                turbo_transformers.set_num_threads(4)

            torch.set_grad_enabled(False)
            self.cfg = AlbertConfig()

            self.torch_model = AlbertModel(self.cfg)
            if torch.cuda.is_available():
                self.torch_model.to(self.test_device)
            self.torch_model.eval()
            self.hidden_size = self.cfg.hidden_size
            self.input_tensor = torch.randint(low=0,
                                              high=self.cfg.vocab_size - 1,
                                              size=(batch_size, seq_length),
                                              device=self.test_device)

            self.turbo_model = turbo_transformers.AlbertModel.from_torch(
                self.torch_model)

        def check_torch_and_turbo(self, use_cuda):
            self.init_data(use_cuda=use_cuda)
            device = "GPU" if use_cuda else "CPU"
            num_iter = 1
            turbo_model = lambda: self.turbo_model(
                self.input_tensor, attention_mask=None, head_mask=None)
            turbo_result, turbo_qps, turbo_time = \
                test_helper.run_model(turbo_model, use_cuda, num_iter)

            print(
                f"AlbertLayer \"({batch_size},{seq_length:03})\" ",
                f"{device} TurboTransform QPS,  {turbo_qps}, time, {turbo_time}"
            )
            torch_model = lambda: self.torch_model(input_ids=self.input_tensor,
                                                   attention_mask=None,
                                                   head_mask=None)
            with turbo_transformers.pref_guard("albert_perf") as perf:
                torch_result, torch_qps, torch_time = \
                    test_helper.run_model(torch_model, use_cuda, num_iter)

            print(f"AlbertModel \"({batch_size},{seq_length:03})\" ",
                  f"{device} Torch QPS,  {torch_qps}, time, {torch_time}")

            # print(turbo_result[-1])
            # print(turbo_result, torch_result[0])
            # TODO(jiaruifang) Error is too high. Does tensor core introduce more differences?
            tolerate_error = 1e-2
            self.assertTrue(
                torch.max(torch.abs(torch_result[0] -
                                    turbo_result[0])) < tolerate_error)

            with open("albert_model_res.txt", "a") as fh:
                fh.write(
                    f"\"({batch_size},{seq_length:03})\", {torch_qps}, {turbo_qps}\n"
                )

        def test_layer(self):
            self.check_torch_and_turbo(use_cuda=False)
            if torch.cuda.is_available() and \
                turbo_transformers.config.is_compiled_with_cuda():
                self.check_torch_and_turbo(use_cuda=True)
Example #22
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 transformer_weights_model: str = None,
                 num_labels: int = 2,
                 predictions_file=None,
                 layer_freeze_regexes: List[str] = None,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._predictions = []

        self._pretrained_model = pretrained_model

        if 't5' in pretrained_model:
            self._padding_value = 1  # The index of the RoBERTa padding token
            if transformer_weights_model:  # Override for RoBERTa only for now
                logging.info(f"Loading Transformer weights model from {transformer_weights_model}")
                transformer_model_loaded = load_archive(transformer_weights_model)
                self._transformer_model = transformer_model_loaded.model._transformer_model
            else:
                self._transformer_model = T5Model.from_pretrained(pretrained_model)
            self._dropout = torch.nn.Dropout(self._transformer_model.config.hidden_dropout_prob)
        elif 'roberta' in pretrained_model:
            self._padding_value = 1  # The index of the RoBERTa padding token
            if transformer_weights_model:  # Override for RoBERTa only for now
                logging.info(f"Loading Transformer weights model from {transformer_weights_model}")
                transformer_model_loaded = load_archive(transformer_weights_model)
                self._transformer_model = transformer_model_loaded.model._transformer_model
            else:
                self._transformer_model = RobertaModel.from_pretrained(pretrained_model)
            self._dropout = torch.nn.Dropout(self._transformer_model.config.hidden_dropout_prob)
        elif 'xlnet' in pretrained_model:
            self._padding_value = 5  # The index of the XLNet padding token
            self._transformer_model = XLNetModel.from_pretrained(pretrained_model)
            self.sequence_summary = SequenceSummary(self._transformer_model.config)
        elif 'albert' in pretrained_model:
            self._transformer_model = AlbertModel.from_pretrained(pretrained_model)
            self._padding_value = 0  # The index of the BERT padding token
            self._dropout = torch.nn.Dropout(self._transformer_model.config.hidden_dropout_prob)
        elif 'bert' in pretrained_model:
            self._transformer_model = BertModel.from_pretrained(pretrained_model)
            self._padding_value = 0  # The index of the BERT padding token
            self._dropout = torch.nn.Dropout(self._transformer_model.config.hidden_dropout_prob)
        else:
            raise ValueError(f"Unsupported pretrained model: {pretrained_model}")

        for name, param in self._transformer_model.named_parameters():
            if layer_freeze_regexes and requires_grad:
                grad = not any([bool(re.search(r, name)) for r in layer_freeze_regexes])
            else:
                grad = requires_grad
            if grad:
                param.requires_grad = True
            else:
                param.requires_grad = False

        transformer_config = self._transformer_model.config
        transformer_config.num_labels = num_labels
        self._output_dim = self._transformer_model.config.hidden_size

        # unifying all model classification layers
        self._classifier = Linear(self._output_dim, num_labels)
        self._classifier.weight.data.normal_(mean=0.0, std=0.02)
        self._classifier.bias.data.zero_()

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        self._debug = -1
Example #23
def wsd(
    model_name='bert-base-uncased',  #ensemble-distil-1-albert-1 / albert-xxlarge-v2 / bert-base-uncased
    classifier_input='token-embedding-last-1-layers',  # token-embedding-last-layer / token-embedding-last-n-layers
    classifier_hidden_layers=[],
    reduce_options=True,
    freeze_base_model=True,
    max_len=512,
    batch_size=32,
    test=False,
    lr=5e-5,
    eps=1e-8,
    n_epochs=50,
    cls_token=False,  # If true, the cls token is used instead of the relevant-word token
    cache_embeddings=False,  # If true, the embeddings from the base model are saved to disk so that they only need to be computed once
    save_classifier=True  # If true, the classifier part of the network is saved after each epoch, and the training is automatically resumed from this saved network if it exists
):
    train_path = "wsd_train.txt"
    test_path = "wsd_test_blind.txt"
    n_classes = 222
    device = 'cuda'

    import __main__ as main
    print("Script: " + os.path.basename(main.__file__))

    print("Loading base model %s..." % model_name)
    if model_name.startswith('ensemble-distil-'):
        last_n_distil = int(model_name.replace('ensemble-distil-', "")[0])
        last_n_albert = int(model_name[-1])
        from transformers import AlbertTokenizer
        from transformers.modeling_albert import AlbertModel
        base_model = AlbertModel.from_pretrained('albert-xxlarge-v2',
                                                 output_hidden_states=True,
                                                 output_attentions=False)
        tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
        print(
            "Ensemble model with DistilBert last %d layers and Albert last %d layers"
            % (last_n_distil, last_n_albert))
    elif model_name.startswith('distilbert'):
        tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        base_model = DistilBertModel.from_pretrained(model_name,
                                                     num_labels=n_classes,
                                                     output_hidden_states=True,
                                                     output_attentions=False)
    elif model_name.startswith('bert'):
        from transformers import BertTokenizer, BertModel
        tokenizer = BertTokenizer.from_pretrained(model_name)
        base_model = BertModel.from_pretrained(model_name,
                                               num_labels=n_classes,
                                               output_hidden_states=True,
                                               output_attentions=False)
    elif model_name.startswith('albert'):
        from transformers import AlbertTokenizer
        from transformers.modeling_albert import AlbertModel
        tokenizer = AlbertTokenizer.from_pretrained(model_name)
        base_model = AlbertModel.from_pretrained(model_name,
                                                 output_hidden_states=True,
                                                 output_attentions=False)

    use_n_last_layers = 1
    if classifier_input == 'token-embedding-last-layer':
        use_n_last_layers = 1
    elif classifier_input.startswith(
            'token-embedding-last-') and classifier_input.endswith('-layers'):
        use_n_last_layers = int(
            classifier_input.replace('token-embedding-last-',
                                     "").replace('-layers', ""))
    else:
        raise ValueError("Invalid classifier_input argument")
    print("Using the last %d layers" % use_n_last_layers)

    def tokenize(text):
        return tokenizer.tokenize(text)[:max_len - 2]

    SENSE = LabelField(is_target=True)
    LEMMA = LabelField()
    TOKEN_POS = LabelField(use_vocab=False)
    TEXT = Field(tokenize=tokenize,
                 pad_token=tokenizer.pad_token,
                 init_token=tokenizer.cls_token,
                 eos_token=tokenizer.sep_token)
    EXAMPLE_ID = LabelField(use_vocab=False)
    fields = [('sense', SENSE), ('lemma', LEMMA), ('token_pos', TOKEN_POS),
              ('text', TEXT), ('example_id', EXAMPLE_ID)]

    def read_data(corpus_file, fields, max_len=None):
        train_id_start = 0
        test_id_start = 76049  # let the ids for the test examples start after the training example indices
        if corpus_file == "wsd_test_blind.txt":
            print("Loading test data...")
            id_start = test_id_start
        else:
            print("Loading train/val data...")
            id_start = train_id_start
        with open(corpus_file, encoding='utf-8') as f:
            examples = []
            for i, line in enumerate(f):
                sense, lemma, word_position, text = line.split('\t')
                # We need to convert from the word position to the token position
                words = text.split()
                pre_word = " ".join(words[:int(word_position)])
                pre_word_tokenized = tokenizer.tokenize(pre_word)
                token_position = len(
                    pre_word_tokenized
                ) + 1  # taking into account the later addition of the start token
                example_id = id_start + i
                if max_len is None or token_position < max_len - 1:  # ignore examples where the relevant token is cut off due to max_len
                    if cls_token:
                        token_position = 0
                    examples.append(
                        Example.fromlist(
                            [sense, lemma, token_position, text, example_id],
                            fields))
                else:
                    print(
                        "Example %d is skipped because the relevant token was cut off (token pos = %d)"
                        % (example_id, token_position))
                    print(text)
        return Dataset(examples, fields)

    dataset = read_data(train_path, fields, max_len)
    random.seed(0)
    trn, vld = dataset.split(0.7, stratified=True, strata_field='sense')

    TEXT.build_vocab([])
    if model_name.startswith('albert') or model_name.startswith(
            'ensemble-distil-'):

        class Mapping:
            def __init__(self, fn):
                self.fn = fn

            def __getitem__(self, item):
                return self.fn(item)

        TEXT.vocab.stoi = Mapping(tokenizer.sp_model.PieceToId)
        TEXT.vocab.itos = Mapping(tokenizer.sp_model.IdToPiece)
    else:
        TEXT.vocab.stoi = tokenizer.vocab
        TEXT.vocab.itos = list(tokenizer.vocab)
    SENSE.build_vocab(trn)
    LEMMA.build_vocab(trn)

    trn_iter = BucketIterator(trn,
                              device=device,
                              batch_size=batch_size,
                              sort_key=lambda x: len(x.text),
                              repeat=False,
                              train=True,
                              sort=True)
    vld_iter = BucketIterator(vld,
                              device=device,
                              batch_size=batch_size,
                              sort_key=lambda x: len(x.text),
                              repeat=False,
                              train=False,
                              sort=True)

    if freeze_base_model:
        for mat in base_model.parameters():
            mat.requires_grad = False  # Freeze Bert model so that we only train the classifier on top

    if reduce_options:
        lemma_mask = defaultdict(
            lambda: torch.zeros(len(SENSE.vocab), device=device))
        for example in trn:
            lemma = LEMMA.vocab.stoi[example.lemma]
            sense = SENSE.vocab.stoi[example.sense]
            lemma_mask[lemma][sense] = 1
        lemma_mask = dict(lemma_mask)

        def mask(
            batch_logits, batch_lemmas
        ):  # Masks out the senses that do not belong to the specified lemma
            for batch_i in range(len(batch_logits)):
                lemma = batch_lemmas[batch_i].item()
                batch_logits[batch_i, :] *= lemma_mask[lemma]
            return batch_logits
    else:

        def mask(batch_logits, batch_lemmas):
            return batch_logits

    experiment_name = model_name + " " + (
        classifier_input if not model_name.startswith('ensemble-distil-') else
        "") + " " + str(classifier_hidden_layers) + " (" + (
            " cls_token" if cls_token else
            "") + (" reduce_options" if reduce_options else "") + (
                " freeze_base_model" if freeze_base_model else ""
            ) + "  ) " + "max_len=" + str(max_len) + " batch_size=" + str(
                batch_size) + " lr=" + str(lr) + " eps=" + str(eps) + (
                    " cache_embeddings" if cache_embeddings else "")

    if model_name.startswith('ensemble-distil-'):
        model = WSDEnsembleModel(last_n_distil, last_n_albert, n_classes, mask,
                                 classifier_hidden_layers)
    else:
        model = WSDModel(base_model, n_classes, mask, use_n_last_layers,
                         model_name, classifier_hidden_layers,
                         cache_embeddings)
    history = None
    #if save_classifier:
    #    if model.load_classifier(experiment_name):
    #        # Existing saved model loaded
    #        # Also load the corresponding training history
    #        history = read_dict_file("results/"+experiment_name+".txt")

    model.cuda()

    print("Starting experiment  " + experiment_name)
    if test:
        tst = read_data(test_path, fields, max_len=512)
        tst_iter = Iterator(tst,
                            device=device,
                            batch_size=batch_size,
                            sort=False,
                            sort_within_batch=False,
                            repeat=False,
                            train=False)
        batch_predictions = []
        for batch in tst_iter:
            print('.', end='')
            sys.stdout.flush()
            text = batch.text.t()
            with torch.no_grad():
                outputs = model(text,
                                token_positions=batch.token_pos,
                                lemmas=batch.lemma,
                                example_ids=batch.example_id)
                scores = outputs[-1]
            batch_predictions.append(scores.argmax(dim=1))
        batch_preds = torch.cat(batch_predictions, 0).tolist()
        predicted_senses = [SENSE.vocab.itos[pred] for pred in batch_preds]
        with open("test_predictions/" + experiment_name + ".txt", "w") as out:
            out.write("\n".join(predicted_senses))
    else:
        no_decay = ['bias', 'LayerNorm.weight']
        decay = 0.01
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            decay
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]
        optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=eps)

        def save_results(history):
            with open("results/" + experiment_name + ".txt", "w") as out:
                out.write(str(history))
            if save_classifier:
                if len(history['val_acc']) < 2 or history['val_acc'][-1] > max(
                        history['val_acc'][:-1]):
                    model.save_classifier(experiment_name, best=True)
                else:
                    model.save_classifier(experiment_name, best=False)

        train(model, optimizer, trn_iter, vld_iter, n_epochs, save_results,
              history)