class BertFine(BertPreTrainedModel):
    def __init__(self, bertConfig, num_classes):
        super(BertFine, self).__init__(bertConfig)
        self.bert = BertModel(bertConfig)  # BERT encoder
        self.dropout = nn.Dropout(bertConfig.hidden_dropout_prob)
        self.classifier = nn.Linear(in_features=bertConfig.hidden_size, out_features=num_classes)
        self.apply(self.init_bert_weights)
        # By default every parameter of the BERT encoder is trained; a batch size of 32 needs roughly 8.7 GB of GPU memory.
        # Freezing the encoder so that only the classifier layer is backpropagated reduces this to roughly 1.1 GB at the same batch size.
        self.unfreeze_bert_encoder()

    def freeze_bert_encoder(self):
        for p in self.bert.parameters():
            p.requires_grad = False

    def unfreeze_bert_encoder(self):
        for p in self.bert.parameters():
            p.requires_grad = True

    def forward(self, input_ids, token_type_ids, attention_mask, label_ids=None, output_all_encoded_layers=False):
        _, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
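# Usage sketch (not part of the original snippet): assuming the pytorch_pretrained_bert-style
# BertConfig used above and a hypothetical bert_config.json, this shows how freezing the
# encoder shrinks the set of parameters handed to the optimizer.
config = BertConfig.from_json_file("bert_config.json")  # hypothetical config path
model = BertFine(config, num_classes=2)
model.freeze_bert_encoder()  # only the classifier keeps requires_grad=True
trainable = [p for p in model.parameters() if p.requires_grad]
print(f"trainable parameters: {sum(p.numel() for p in trainable)}")
optimizer = torch.optim.Adam(trainable, lr=2e-5)
# model.unfreeze_bert_encoder() restores full fine-tuning (and the larger memory footprint).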
class BertForSiameseClassification(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForSiameseClassification, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 2)
        self.apply(self.init_bert_weights)
        self.avg_vec = AvgVec()

    def forward(self, input_ids_1, input_mask_1, input_ids_2, input_mask_2):
        # BERT is kept in eval mode, so its dropout layers stay disabled even during training.
        self.bert.eval()
        encoder_layer_1, pooled_output_1 = self.bert(
            input_ids_1, token_type_ids=None, attention_mask=input_mask_1)
        encoder_layer_2, pooled_output_2 = self.bert(
            input_ids_2, token_type_ids=None, attention_mask=input_mask_2)
        out1 = self.avg_vec(encoder_layer_1, input_mask_1)
        out2 = self.avg_vec(encoder_layer_2, input_mask_2)
        diff = torch.abs(out1 - out2)
        logit = self.classifier(diff)
        softmax = F.softmax(logit, dim=1)
        return logit, softmax

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
class BertForMultiLabelSequenceClassification(BertPreTrainedModel): def __init__(self, config, num_labels): super(BertForMultiLabelSequenceClassification, self).__init__(config) self.num_labels = num_labels self.bert = BertModel(config) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.classifier = torch.nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) if labels is not None: #loss_fct = BCEWithLogitsLoss() #loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return loss else: return logits def freeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = True
class BertForMultiLabelSequenceClassification(BertPreTrainedModel): """BERT model for classification. This module is composed of the BERT model with a linear layer on top of the pooled output. """ def __init__(self, config, num_labels=2): super(BertForMultiLabelSequenceClassification, self).__init__(config) self.num_labels = num_labels self.hidden_size = config.hidden_size self.mem_size = 512 self.bert = BertModel(config) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.att = DocAttNet(sent_hidden_size=config.hidden_size, doc_hidden_size = self.mem_size, num_classes = num_labels) self.classifier = torch.nn.Linear(self.mem_size *2, num_labels) self.classifier2 = torch.nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward2(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): _, pooled_output = self.bert(input_id, token_type_id, attention_mask, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier2(pooled_output) return logits def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, long_doc=True): #import pdb; pdb.set_trace() if long_doc: #self.freeze_bert_encoder() zs = [] for i in range(input_ids.shape[1]): _, pooled_output = self.bert(input_ids[:,i], token_type_ids[:,i], attention_mask[:,i], output_all_encoded_layers=False) #pooled_output = self.dropout(pooled_output) zs.append(pooled_output.detach()) mem = torch.zeros(2, input_ids.shape[0], self.mem_size).cuda() attention_output, word_attn_norm = self.att( torch.stack(zs, 0), mem) attention_output = self.dropout(attention_output) logits = self.classifier(attention_output) return logits, word_attn_norm else: #self.unfreeze_bert_encoder() _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier2(pooled_output) return logits def freeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = True
class BertForMultiLabelSequenceClassification(BertPreTrainedModel): def __init__(self, config, num_labels): super(BertForMultiLabelSequenceClassification, self).__init__(config) self.num_labels = num_labels self.bert = BertModel(config) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.classifier = torch.nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) @property def device(self) -> torch.device: return self.classifier.weight.device def forward(self, input_ids: torch.Tensor, token_type_ids=None, attention_mask: Optional[torch.Tensor] = None, labels=None): _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) if labels is not None: loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)) return loss else: return logits def freeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = True @classmethod def load(cls, modelpath, config, num_labels): print('loading model from [%s]' % modelpath, file=sys.stderr) model = cls(config, num_labels) model.load_state_dict(torch.load(modelpath)) return model
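# Usage sketch (not from the original source): the save/load round trip that the `load`
# classmethod above expects. The file name, config path, and label count are hypothetical.
config = BertConfig.from_json_file("bert_config.json")  # hypothetical config path
model = BertForMultiLabelSequenceClassification(config, num_labels=8)
torch.save(model.state_dict(), "multilabel_bert.pt")  # counterpart of `load`
restored = BertForMultiLabelSequenceClassification.load("multilabel_bert.pt", config, num_labels=8)
print(restored.device)  # the property resolves to wherever the classifier weights live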
class BertFine(BertPreTrainedModel): def __init__(self,bertConfig,num_classes): super(BertFine ,self).__init__(bertConfig) self.bert = BertModel(bertConfig) self.dropout = nn.Dropout(bertConfig.hidden_dropout_prob) n = 1 if config['feature-based'] == 'Finetune_All': n = bertConfig.num_hidden_layers elif config['feature-based'] == 'Second_to_Last': n = bertConfig.num_hidden_layers-1 elif config['feature-based'] == 'Concat_Last_Four': n = 4 self.pooler = BertFinalPooler(bertConfig.hidden_size, n) self.classifier = nn.Linear(in_features=bertConfig.hidden_size*n, out_features=num_classes) self.apply(self.init_bert_weights) self.unfreeze_bert_encoder() def freeze_bert_encoder(self): for p in self.bert.parameters(): p.requires_grad = False def unfreeze_bert_encoder(self): for p in self.bert.parameters(): p.requires_grad = True def forward(self, input_ids, token_type_ids, attention_mask, label_ids=None, output_all_encoded_layers=True): encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=output_all_encoded_layers) if config['feature-based'] != 'Last': if config['feature-based'] == 'Finetune_All': sequence_output = torch.cat(encoded_layers,2) elif config['feature-based'] == 'First': sequence_output = encoded_layers[0] elif config['feature-based'] == 'Second_to_Last': sequence_output = torch.cat(encoded_layers[1:],1) elif config['feature-based'] == 'Sum_Last_Four': sequence_output = sum(encoded_layers[-4:]) elif config['feature-based'] == 'Concat_Last_Four': sequence_output = torch.cat(encoded_layers[-4:],2) elif config['feature-based'] == 'Sum_All': sequence_output = sum(encoded_layers) pooled_output = self.pooler(sequence_output) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) return logits
class BertForMultiLabelSequenceClassification(BertPreTrainedModel ): # type: ignore """Make a good docstring!""" def __init__(self, config: BertConfig, num_labels: int = 2): super(BertForMultiLabelSequenceClassification, self).__init__(config) self.num_labels = num_labels self.bert = BertModel(config) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.classifier = torch.nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward( self, input_ids: Tensor, token_type_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, pos_weight: Tensor = None, ) -> Tensor: _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) if labels is not None: if pos_weight is None: loss_fct = BCEWithLogitsLoss() else: loss_fct = BCEWithLogitsLoss(pos_weight=pos_weight) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)) return loss else: return logits def freeze_bert_encoder(self) -> None: for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self) -> None: for param in self.bert.parameters(): param.requires_grad = True
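# Sketch (an assumption, not from the original repo): one common way to build the `pos_weight`
# tensor that this forward pass hands to BCEWithLogitsLoss is the per-label ratio of negative
# to positive examples in the training label matrix.
import torch

def make_pos_weight(label_matrix: torch.Tensor) -> torch.Tensor:
    """label_matrix: float tensor of shape [num_examples, num_labels] with 0/1 entries."""
    positives = label_matrix.sum(dim=0)              # count of 1s per label
    negatives = label_matrix.size(0) - positives     # count of 0s per label
    return negatives / positives.clamp(min=1.0)      # values > 1 up-weight rare positive labels

# loss = model(input_ids, token_type_ids, attention_mask, labels=labels,
#              pos_weight=make_pos_weight(train_labels))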
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of the pooled output.
    """
    def __init__(self, config, num_labels=17, mobilebert=True):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.mobilebert = mobilebert
        self.num_labels = num_labels
        self.bert = BertModel(config) if not mobilebert else MobileBertModel.from_pretrained(
            'google/mobilebert-uncased', num_labels=num_labels)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        if not mobilebert:
            self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        if labels is None:
            return logits
        labels = labels.to(torch.float)
        loss_fct = BCEWithLogitsLoss()
        loss = loss_fct(logits, labels)
        return loss, logits

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
class BertForMultiLabelSequenceClassification(BertPreTrainedModel): """BERT model for classification. This module is composed of the BERT model with a linear layer on top of the pooled output. """ def __init__(self, config, num_labels=2): super(BertForMultiLabelSequenceClassification, self).__init__(config) self.num_labels = num_labels self.weight = torch.tensor([0.1, 0.1, 0.2, 0.6]) self.bert = BertModel(config) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.classifier = torch.nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) if labels is not None: loss_fct = torch.nn.BCEWithLogitsLoss() #BCEWithLogitsLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)) return loss, logits else: return logits def freeze(self): for param in self.bert.parameters(): param.requires_grad = False def unfreeze(self): for param in self.bert.parameters(): param.requires_grad = True
class BertForSiameseClassification(BertPreTrainedModel): def __init__(self, config): super(BertForSiameseClassification, self).__init__(config) self.bert = BertModel(config) self.apply(self.init_bert_weights) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.cosVec = cosVec() def forward(self, input_ids_1, input_mask_1, input_ids_2, input_mask_2): encoder_layer_1, pooled_output_1 = self.bert( input_ids_1, token_type_ids=None, attention_mask=input_mask_1) encoder_layer_2, pooled_output_2 = self.bert( input_ids_2, token_type_ids=None, attention_mask=input_mask_2) sim = self.cosVec(pooled_output_1, pooled_output_2) return pooled_output_1, pooled_output_2, sim def freeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = True
def add_adapters(bert_model: BertModel, config: AdapterConfig) -> BertModel: bert_encoder = bert_model.encoder for i in range(len(bert_model.encoder.layer)): bert_encoder.layer[i].attention.output = adapt_bert_self_output( config)(bert_encoder.layer[i].attention.output) # Freeze all parameters for param in bert_model.parameters(): param.requires_grad = False # Unfreeze trainable parts — layer norms and adapters for name, sub_module in bert_model.named_modules(): if isinstance(sub_module, (Adapter, BertLayerNorm)): for param_name, param in sub_module.named_parameters(): param.requires_grad = True return bert_model
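# The `Adapter` and `adapt_bert_self_output` helpers referenced above are not shown. A minimal
# sketch of what they could look like for a Houlsby-style bottleneck adapter follows; the class
# names, the wrapping strategy, and the AdapterConfig fields (`hidden_size`, `adapter_size`)
# are assumptions, not the repository's actual implementation.
import torch.nn as nn

class Adapter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.down = nn.Linear(config.hidden_size, config.adapter_size)  # bottleneck down-projection
        self.up = nn.Linear(config.adapter_size, config.hidden_size)    # back up to the hidden size
        self.act = nn.GELU()

    def forward(self, hidden_states):
        # Inner residual connection around the bottleneck.
        return hidden_states + self.up(self.act(self.down(hidden_states)))

class AdaptedBertSelfOutput(nn.Module):
    """Wraps a BertSelfOutput module so the adapter runs before the residual LayerNorm."""
    def __init__(self, self_output, config):
        super().__init__()
        self.self_output = self_output
        self.adapter = Adapter(config)  # registered as a submodule so named_modules() can unfreeze it

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.self_output.dense(hidden_states)
        hidden_states = self.self_output.dropout(hidden_states)
        hidden_states = self.adapter(hidden_states)
        return self.self_output.LayerNorm(hidden_states + input_tensor)

def adapt_bert_self_output(config):
    # Returns a callable matching the usage above: it takes the original attention.output
    # module and returns the adapter-augmented replacement.
    return lambda self_output: AdaptedBertSelfOutput(self_output, config)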
def __init__(self, bert_config: str, requires_grad: bool = False, dropout: float = 0.1, layer_dropout: float = 0.1, combine_layers: str = "mix") -> None: model = BertModel(BertConfig.from_json_file(bert_config)) for param in model.parameters(): param.requires_grad = requires_grad super().__init__(bert_model=model, layer_dropout=layer_dropout, combine_layers=combine_layers) self.model = model self.dropout = dropout self.set_dropout(dropout)
class BertForLabelEncoding(PreTrainedBertModel): def __init__(self, config, trainable=False): super(BertForLabelEncoding, self).__init__(config) self.config = config self.bert = BertModel(config) #self.apply(self.init_bert_weights) # don't need to perform due to pre-trained params loading if not trainable: for p in self.bert.parameters(): p.requires_grad = False def forward(self, input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False): _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers) return pooled_output
class BertLSTMForClassification(BertPreTrainedModel):
    def __init__(self, config, encoder, attention, hidden_dim, num_labels):
        super(BertLSTMForClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.encoder = encoder
        self.attention = attention
        self.decoder = nn.Linear(hidden_dim, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                                  output_all_encoded_layers=False)
        outputs, hidden = self.encoder(encoded_layers)
        if isinstance(hidden, tuple):  # LSTM
            hidden = hidden[1]  # take the cell state
        if self.encoder.bidirectional:
            # need to concat the last 2 hidden layers
            hidden = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden = hidden[-1]
        # Other options (work worse on a few tests):
        # linear_combination, _ = torch.max(outputs, 0)
        # linear_combination = torch.mean(outputs, 0)
        energy, linear_combination = self.attention(hidden, outputs, outputs)
        logits = self.decoder(linear_combination)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits

    def freeze_bert(self):
        for param in self.bert.parameters():
            param.requires_grad = False
class BertForMultiLabelSequenceClassification(BertPreTrainedModel): """BERT model for classification. This module is composed of the BERT model with a linear layer on top of the pooled output. Params: `config`: a BertConfig class instance with the configuration to build a new model. `num_labels`: the number of classes for the classifier. Default = 2. Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, num_labels] with indices selected in [0, ..., num_labels]. Outputs: if `labels` is not `None`: Outputs the CrossEntropy classification loss of the output with the labels. if `labels` is `None`: Outputs the classification logits of shape [batch_size, num_labels]. """ def __init__(self, config, num_labels=2, loss_fct="bbce"): super(BertForMultiLabelSequenceClassification, self).__init__(config) self.num_labels = num_labels self.loss_fct = loss_fct self.bert = BertModel(config) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.classifier = torch.nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) if labels is not None: if self.loss_fct == "bbce": loss_fct = BalancedBCEWithLogitsLoss() else: loss_fct = torch.nn.MultiLabelSoftMarginLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)) return loss else: return logits def freeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = True
class BertQAYesnoHierarchicalReinforceRACE(BertPreTrainedModel): """ Hard attention using reinforce learning """ def __init__(self, config, evidence_lambda=0.8, num_choices=4, sample_steps: int = 5, reward_func: int = 0, freeze_bert=False): super(BertQAYesnoHierarchicalReinforceRACE, self).__init__(config) logger.info(f'The model {self.__class__.__name__} is loading...') logger.info(f'The coefficient of evidence loss is {evidence_lambda}') logger.info(f'Currently the number of choices is {num_choices}') logger.info(f'Sample steps: {sample_steps}') logger.info(f'Reward function: {reward_func}') logger.info(f'If freeze BERT\'s parameters: {freeze_bert} ') layers.set_seq_dropout(True) layers.set_my_dropout_prob(config.hidden_dropout_prob) rep_layers.set_seq_dropout(True) rep_layers.set_my_dropout_prob(config.hidden_dropout_prob) self.bert = BertModel(config) if freeze_bert: for p in self.bert.parameters(): p.requires_grad = False self.doc_sen_self_attn = rep_layers.LinearSelfAttention( config.hidden_size) self.que_self_attn = rep_layers.LinearSelfAttention(config.hidden_size) self.word_similarity = layers.AttentionScore(config.hidden_size, 250, do_similarity=False) self.vector_similarity = layers.AttentionScore(config.hidden_size, 250, do_similarity=False) # self.yesno_predictor = nn.Linear(config.hidden_size * 2, 3) self.classifier = nn.Linear(config.hidden_size * 2, 1) self.evidence_lam = evidence_lambda self.sample_steps = sample_steps self.reward_func = [self.reinforce_step, self.reinforce_step_1][reward_func] self.num_choices = num_choices self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, sentence_span_list=None, sentence_ids=None, max_sentences: int = 0): flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_token_type_ids = token_type_ids.view( -1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view( -1, attention_mask.size(-1)) if attention_mask is not None else None sequence_output, _ = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) # mask: 1 for masked value and 0 for true value # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask) doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \ rep_layers.split_doc_sen_que(sequence_output, flat_token_type_ids, flat_attention_mask, sentence_span_list, max_sentences=max_sentences) batch, max_sen, doc_len = doc_sen_mask.size() que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1) doc = doc_sen.reshape(batch, max_sen * doc_len, -1) word_sim = self.word_similarity(que_vec, doc).view(batch * max_sen, doc_len) doc = doc_sen.reshape(batch * max_sen, doc_len, -1) doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len) word_hidden = rep_layers.masked_softmax(word_sim, doc_mask, dim=1).unsqueeze(1).bmm(doc) word_hidden = word_hidden.view(batch, max_sen, -1) doc_vecs = self.doc_sen_self_attn(doc, doc_mask).view(batch, max_sen, -1) sentence_sim = self.vector_similarity(que_vec, doc_vecs) if self.training: _sample_prob, _sample_log_prob = self.sample_one_hot( sentence_sim, sentence_mask) loss_and_reward, _ = self.reward_func(word_hidden, que_vec, labels, _sample_prob, _sample_log_prob) output_dict = {'loss': loss_and_reward} else: _prob, _ = self.sample_one_hot(sentence_sim, sentence_mask) loss, _choice_logits = self.simple_step(word_hidden, que_vec, labels, _prob) sentence_scores = 
rep_layers.masked_softmax(sentence_sim, sentence_mask, dim=-1).squeeze_(1) output_dict = { 'sentence_logits': sentence_scores.float(), 'loss': loss, 'choice_logits': _choice_logits.float() } return output_dict def sample_one_hot(self, _similarity, _mask): _probability = rep_layers.masked_softmax(_similarity, _mask) dtype = _probability.dtype _probability = _probability.float() # _log_probability = masked_log_softmax(_similarity, _mask) if self.training: _distribution = Categorical(_probability) _sample_index = _distribution.sample((self.sample_steps, )) logger.debug(str(_sample_index.size())) new_shape = (self.sample_steps, ) + _similarity.size() logger.debug(str(new_shape)) _sample_one_hot = F.one_hot(_sample_index, num_classes=_similarity.size(-1)) # _sample_one_hot = _similarity.new_zeros(new_shape).scatter(-1, _sample_index.unsqueeze(-1), 1.0) logger.debug(str(_sample_one_hot.size())) _log_prob = _distribution.log_prob( _sample_index) # sample_steps, batch, 1 assert _log_prob.size() == new_shape[:-1], (_log_prob.size(), new_shape) _sample_one_hot = _sample_one_hot.transpose( 0, 1) # batch, sample_steps, 1, max_sen _log_prob = _log_prob.transpose(0, 1) # batch, sample_steps, 1 return _sample_one_hot.to(dtype=dtype), _log_prob.to(dtype=dtype) else: _max_index = _probability.float().max(dim=-1, keepdim=True)[1] _one_hot = torch.zeros_like(_similarity).scatter_( -1, _max_index, 1.0) # _log_prob = _log_probability.gather(-1, _max_index) return _one_hot, None def reinforce_step(self, hidden, q_vec, label, prob, log_prob): batch, max_sen, hidden_dim = hidden.size() assert q_vec.size() == (batch, 1, hidden_dim) assert prob.size() == (batch, self.sample_steps, 1, max_sen) assert log_prob.size() == (batch, self.sample_steps, 1) expanded_hidden = hidden.unsqueeze(1).expand(-1, self.sample_steps, -1, -1) h = prob.matmul(expanded_hidden).squeeze( 2) # batch, sample_steps, hidden_dim q = q_vec.expand(-1, self.sample_steps, -1) # _logits = self.classifier(torch.cat([h, q], dim=2)).view(-1, self.num_choices) # batch, sample_steps, 3 # Note the rank of dimension here _logits = self.classifier(torch.cat([h, q], dim=2)).view(label.size(0), self.num_choices, self.sample_steps)\ .transpose(1, 2).reshape(-1, self.num_choices) expanded_label = label.unsqueeze(1).expand( -1, self.sample_steps).reshape(-1) _loss = F.cross_entropy(_logits, expanded_label) corrects = (_logits.max(dim=-1)[1] == expanded_label).to(hidden.dtype) log_prob = log_prob.reshape(label.size(0), self.num_choices, self.sample_steps).transpose( 1, 2).mean(dim=-1) reward1 = (log_prob.reshape(-1) * corrects).sum() / (self.sample_steps * label.size(0)) return _loss - reward1, _logits def reinforce_step_1(self, hidden, q_vec, label, prob, log_prob): batch, max_sen, hidden_dim = hidden.size() assert q_vec.size() == (batch, 1, hidden_dim) assert prob.size() == (batch, self.sample_steps, 1, max_sen) assert log_prob.size() == (batch, self.sample_steps, 1) expanded_hidden = hidden.unsqueeze(1).expand(-1, self.sample_steps, -1, -1) h = prob.matmul(expanded_hidden).squeeze( 2) # batch, sample_steps, hidden_dim q = q_vec.expand(-1, self.sample_steps, -1) # _logits = self.classifier(torch.cat([h, q], dim=2)).view(-1, self.num_choices) # batch * sample_steps, 3 _logits = self.classifier(torch.cat([h, q], dim=2)).view(label.size(0), self.num_choices, self.sample_steps)\ .transpose(1, 2).reshape(-1, self.num_choices) expanded_label = label.unsqueeze(1).expand( -1, self.sample_steps).reshape(-1) # batch * sample_steps _loss = F.cross_entropy(_logits, 
expanded_label) _final_log_prob = F.log_softmax(_logits, dim=-1) # ignore_mask = (expanded_label == -1) # expanded_label = expanded_label.masked_fill(ignore_mask, 0) selected_log_prob = _final_log_prob.gather( 1, expanded_label.unsqueeze(1)).squeeze(-1) # batch * sample_steps assert selected_log_prob.size() == ( label.size(0) * self.sample_steps, ), selected_log_prob.size() log_prob = log_prob.reshape(label.size(0), self.num_choices, self.sample_steps).transpose( 1, 2).mean(dim=-1) # reward2 = - (log_prob.reshape(-1) * (selected_log_prob * (1 - ignore_mask).to(log_prob.dtype))).sum() / ( # self.sample_steps * batch) reward2 = -(log_prob.reshape(-1) * selected_log_prob).sum() / ( self.sample_steps * label.size(0)) return _loss - reward2, _logits def simple_step(self, hidden, q_vec, label, prob): batch, max_sen, hidden_dim = hidden.size() assert q_vec.size() == (batch, 1, hidden_dim) assert prob.size() == (batch, 1, max_sen) h = prob.bmm(hidden) _logits = self.classifier(torch.cat([h, q_vec], dim=2)).view(-1, self.num_choices) if label is not None: _loss = F.cross_entropy(_logits, label) else: _loss = _logits.new_zeros(1) return _loss, _logits
class BertQAYesnoHierarchicalHardRACE(BertPreTrainedModel): """ Hard: Hard attention, using gumbel softmax of reinforcement learning. """ def __init__(self, config, evidence_lambda=0.8, num_choices=4, use_gumbel=True, freeze_bert=False): super(BertQAYesnoHierarchicalHardRACE, self).__init__(config) logger.info(f'The model {self.__class__.__name__} is loading...') logger.info(f'The coefficient of evidence loss is {evidence_lambda}') logger.info(f'Currently the number of choices is {num_choices}') logger.info(f'Use gumbel: {use_gumbel}') logger.info(f'If freeze BERT\'s parameters: {freeze_bert} ') layers.set_seq_dropout(True) layers.set_my_dropout_prob(config.hidden_dropout_prob) rep_layers.set_seq_dropout(True) rep_layers.set_my_dropout_prob(config.hidden_dropout_prob) self.bert = BertModel(config) if freeze_bert: for p in self.bert.parameters(): p.requires_grad = False # self.doc_sen_self_attn = layers.LinearSelfAttnAllennlp(config.hidden_size) # self.que_self_attn = layers.LinearSelfAttn(config.hidden_size) self.doc_sen_self_attn = rep_layers.LinearSelfAttention( config.hidden_size) self.que_self_attn = rep_layers.LinearSelfAttention(config.hidden_size) self.word_similarity = layers.AttentionScore(config.hidden_size, 250, do_similarity=False) self.vector_similarity = layers.AttentionScore(config.hidden_size, 250, do_similarity=False) self.classifier = nn.Linear(config.hidden_size * 2, 1) self.evidence_lam = evidence_lambda self.use_gumbel = use_gumbel self.num_choices = num_choices self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, sentence_span_list=None, sentence_ids=None, max_sentences: int = 0): flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_token_type_ids = token_type_ids.view( -1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view( -1, attention_mask.size(-1)) if attention_mask is not None else None sequence_output, _ = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) # mask: 1 for masked value and 0 for true value # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask) doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \ rep_layers.split_doc_sen_que(sequence_output, flat_token_type_ids, flat_attention_mask, sentence_span_list, max_sentences=max_sentences) batch, max_sen, doc_len = doc_sen_mask.size() # que_len = que_mask.size(1) # que_vec = layers.weighted_avg(que, self.que_self_attn(que, que_mask)).view(batch, 1, -1) que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1) doc = doc_sen.reshape(batch, max_sen * doc_len, -1) word_sim = self.word_similarity(que_vec, doc).view(batch * max_sen, doc_len) doc = doc_sen.reshape(batch * max_sen, doc_len, -1) doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len) word_hidden = rep_layers.masked_softmax(word_sim, doc_mask, dim=1).unsqueeze(1).bmm(doc) word_hidden = word_hidden.view(batch, max_sen, -1) doc_vecs = self.doc_sen_self_attn(doc, doc_mask).view(batch, max_sen, -1) sentence_sim = self.vector_similarity(que_vec, doc_vecs) sentence_hidden = self.hard_sample( sentence_sim, use_gumbel=self.use_gumbel, dim=-1, hard=True, mask=sentence_mask).bmm(word_hidden).squeeze(1) choice_logits = self.classifier( torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1)).reshape(-1, self.num_choices) sentence_scores = rep_layers.masked_softmax(sentence_sim, sentence_mask, dim=-1).squeeze_(1) output_dict = { 
'choice_logits': choice_logits.float(), 'sentence_logits': sentence_scores.reshape(choice_logits.size(0), self.num_choices, max_sen).detach().cpu().float(), } loss = 0 if labels is not None: choice_loss = F.cross_entropy(choice_logits, labels) loss += choice_loss if sentence_ids is not None: log_sentence_sim = rep_layers.masked_log_softmax( sentence_sim.squeeze(1), sentence_mask, dim=-1) sentence_loss = F.nll_loss(log_sentence_sim, sentence_ids.view(batch), reduction='sum', ignore_index=-1) loss += self.evidence_lam * sentence_loss / choice_logits.size(0) output_dict['loss'] = loss return output_dict def hard_sample(self, logits, use_gumbel, dim=-1, hard=True, mask=None): if use_gumbel: if self.training: probs = rep_layers.gumbel_softmax(logits, mask=mask, hard=hard, dim=dim) return probs else: probs = rep_layers.masked_softmax(logits, mask, dim=dim) index = probs.max(dim, keepdim=True)[1] y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0) return y_hard else: pass
class SANBertNetwork(nn.Module): def __init__(self, opt, bert_config=None): super(SANBertNetwork, self).__init__() self.dropout_list = nn.ModuleList() self.bert_config = BertConfig.from_dict(opt) self.bert = BertModel(self.bert_config) if opt.get('dump_feature', False): self.opt = opt return if opt['update_bert_opt'] > 0: for p in self.bert.parameters(): p.requires_grad = False mem_size = self.bert_config.hidden_size self.decoder_opt = opt['answer_opt'] self.scoring_list = nn.ModuleList() labels = [int(ls) for ls in opt['label_size'].split(',')] task_dropout_p = opt['tasks_dropout_p'] self.bert_pooler = None for task, lab in enumerate(labels): decoder_opt = self.decoder_opt[task] dropout = DropoutWrapper(task_dropout_p[task], opt['vb_dropout']) self.dropout_list.append(dropout) if decoder_opt == 1: out_proj = SANClassifier(mem_size, mem_size, lab, opt, prefix='answer', dropout=dropout) self.scoring_list.append(out_proj) else: out_proj = nn.Linear(self.bert_config.hidden_size, lab) self.scoring_list.append(out_proj) self.opt = opt self._my_init() self.set_embed(opt) def _my_init(self): def init_weights(module): if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range * self.opt['init_ratio']) elif isinstance(module, BertLayerNorm): # Slightly different from the BERT pytorch version, which should be a bug. # Note that it only affects on training from scratch. For detailed discussions, please contact xiaodl@. # Layer normalization (https://arxiv.org/abs/1607.06450) # support both old/latest version if 'beta' in dir(module) and 'gamma' in dir(module): module.beta.data.zero_() module.gamma.data.fill_(1.0) else: module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear): module.bias.data.zero_() self.apply(init_weights) def nbert_layer(self): return len(self.bert.encoder.layer) def freeze_layers(self, max_n): assert max_n < self.nbert_layer() for i in range(0, max_n): self.freeze_layer(i) def freeze_layer(self, n): assert n < self.nbert_layer() layer = self.bert.encoder.layer[n] for p in layer.parameters(): p.requires_grad = False def set_embed(self, opt): bert_embeddings = self.bert.embeddings emb_opt = opt['embedding_opt'] if emb_opt == 1: for p in bert_embeddings.word_embeddings.parameters(): p.requires_grad = False elif emb_opt == 2: for p in bert_embeddings.position_embeddings.parameters(): p.requires_grad = False elif emb_opt == 3: for p in bert_embeddings.token_type_embeddings.parameters(): p.requires_grad = False elif emb_opt == 4: for p in bert_embeddings.token_type_embeddings.parameters(): p.requires_grad = False for p in bert_embeddings.position_embeddings.parameters(): p.requires_grad = False def forward(self, input_ids, token_type_ids, attention_mask, premise_mask=None, hyp_mask=None, task_id=0): all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) sequence_output = all_encoder_layers[-1] if self.bert_pooler is not None: pooled_output = self.bert_pooler(sequence_output) decoder_opt = self.decoder_opt[task_id] if decoder_opt == 1: max_query = hyp_mask.size(1) assert max_query > 0 assert premise_mask is not None assert hyp_mask is not None hyp_mem = sequence_output[:, :max_query, :] logits = self.scoring_list[task_id](sequence_output, hyp_mem, premise_mask, hyp_mask) else: pooled_output = 
self.dropout_list[task_id](pooled_output) logits = self.scoring_list[task_id](pooled_output) return logits
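# Usage sketch (an assumption: `opt` is an option dict containing the keys the constructor reads,
# such as 'label_size', 'answer_opt', and 'tasks_dropout_p'): partial freezing with the helpers
# above, e.g. stopping gradients through the lowest six encoder layers and training the rest.
import torch

model = SANBertNetwork(opt)
model.freeze_layers(6)  # layers 0..5 stop receiving gradient updates
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=5e-5)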
class Bert_CRF(BertPreTrainedModel): def __init__(self, config, num_tag): super(Bert_CRF, self).__init__(config) self.bert = BertModel(config) if args.do_not_train_ernie: for p in self.bert.parameters(): p.requires_grad = False self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_tag) self.apply(self.init_bert_weights) self.crf = CRF(num_tag) self.num_tag = num_tag def forward(self, input_ids, token_type_ids, attention_mask, label_id=None, output_all_encoded_layers=False): bert_encode, _ = self.bert( input_ids, token_type_ids, attention_mask, output_all_encoded_layers=output_all_encoded_layers) output = self.classifier(bert_encode) return output def loss_fn(self, bert_encode, output_mask, tags): if args.do_CRF: loss = self.crf.negative_log_loss(bert_encode, output_mask, tags) else: loss = torch.autograd.Variable(torch.tensor(0.), requires_grad=True) for ix, (features, tag) in enumerate(zip(bert_encode, tags)): num_valid = torch.sum(output_mask[ix].detach()) features = features[output_mask[ix] == 1] tag = tag[:num_valid] loss_fct = nn.CrossEntropyLoss(ignore_index=0) loss = loss + loss_fct( features.view(-1, self.num_tag).cpu(), tag.view(-1).cpu()) return loss def predict(self, bert_encode, output_mask): if args.do_CRF: predicts = self.crf.get_batch_best_path(bert_encode, output_mask) if not args.do_inference: predicts = predicts.view(1, -1).squeeze() predicts = predicts[predicts != -1] else: predicts_ = [] for ix, features, in enumerate(predicts): #features = features[output_mask[ix] == 1] predict = features[features != -1] predicts_.append(predict) predicts = predicts_ else: predicts_ = [] for ix, features, in enumerate(bert_encode): features = features[output_mask[ix] == 1] predict = F.softmax(features, dim=1) predict = torch.argmax(predict, dim=1) predicts_.append(predict) if not args.do_inference: predicts = torch.cat(predicts_, 0) else: predicts = predicts_ return predicts def acc_f1(self, y_pred, y_true): try: y_pred = y_pred.numpy() y_true = y_true.numpy() except: pass f1 = f1_score(y_true, y_pred, average="macro") correct = np.sum((y_true == y_pred).astype(int)) acc = correct / y_pred.shape[0] return acc, f1 def class_report(self, y_pred, y_true): y_true = y_true.numpy() y_pred = y_pred.numpy() classify_report = classification_report(y_true, y_pred) print('\n\nclassify_report:\n', classify_report)
class BertQAYesnoHierarchicalReinforce(BertPreTrainedModel): """ Hard attention using reinforce learning """ def __init__(self, config, evidence_lambda=0.8, sample_steps: int = 5, reward_func: int = 0, freeze_bert=False): super(BertQAYesnoHierarchicalReinforce, self).__init__(config) logger.info(f'The model {self.__class__.__name__} is loading...') logger.info(f'The coefficient of evidence loss is {evidence_lambda}') logger.info(f'Sample steps: {sample_steps}') logger.info(f'Reward function: {reward_func}') logger.info(f'If freeze BERT\'s parameters: {freeze_bert} ') layers.set_seq_dropout(True) layers.set_my_dropout_prob(config.hidden_dropout_prob) self.bert = BertModel(config) if freeze_bert: for p in self.bert.parameters(): p.requires_grad = False self.doc_sen_self_attn = layers.LinearSelfAttnAllennlp( config.hidden_size) self.que_self_attn = layers.LinearSelfAttn(config.hidden_size) self.word_similarity = layers.AttentionScore(config.hidden_size, 250, do_similarity=False) self.vector_similarity = layers.AttentionScore(config.hidden_size, 250, do_similarity=False) self.yesno_predictor = nn.Linear(config.hidden_size * 2, 3) self.evidence_lam = evidence_lambda self.sample_steps = sample_steps self.reward_func = [self.reinforce_step, self.reinforce_step_1][reward_func] self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, answer_choice=None, sentence_span_list=None, sentence_ids=None): sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) # mask: 1 for masked value and 0 for true value # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask) doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \ layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list) batch, max_sen, doc_len = doc_sen_mask.size() # que_len = que_mask.size(1) que_vec = layers.weighted_avg(que, self.que_self_attn(que, que_mask)).view( batch, 1, -1) doc = doc_sen.reshape(batch, max_sen * doc_len, -1) # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len] word_sim = self.word_similarity(que_vec, doc).view(batch * max_sen, doc_len) doc = doc_sen.reshape(batch * max_sen, doc_len, -1) doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len) # [batch * max_sen, doc_len] -> [batch * max_sen, 1, doc_len] -> [batch * max_sen, 1, h] word_hidden = masked_softmax(word_sim, 1 - doc_mask, dim=1).unsqueeze(1).bmm(doc) word_hidden = word_hidden.view(batch, max_sen, -1) doc_vecs = layers.weighted_avg(doc, self.doc_sen_self_attn(doc, doc_mask)).view( batch, max_sen, -1) # [batch, 1, h] # sentence_hidden = self.vector_similarity(que_vec, doc_vecs, x2_mask=sentence_mask, x3=word_hidden).squeeze(1) # [batch, 1, max_sen] sentence_sim = self.vector_similarity(que_vec, doc_vecs) # sentence_hidden = self.hard_sample(sentence_sim, use_gumbel=self.use_gumbel, dim=-1, # hard=True, mask=(1 - sentence_mask)).bmm(word_hidden).squeeze(1) if self.training: _sample_prob, _sample_log_prob = self.sample_one_hot( sentence_sim, 1 - sentence_mask) loss_and_reward, _ = self.reward_func(word_hidden, que_vec, answer_choice, _sample_prob, _sample_log_prob) output_dict = {'loss': loss_and_reward} else: _prob, _ = self.sample_one_hot(sentence_sim, 1 - sentence_mask) loss, _yesno_logits = self.simple_step(word_hidden, que_vec, answer_choice, _prob) sentence_scores = masked_softmax(sentence_sim, 1 - sentence_mask, dim=-1).squeeze_(1) output_dict = { 'max_weight': 
sentence_scores.max(dim=1)[0], 'max_weight_index': sentence_scores.max(dim=1)[1], 'sentence_logits': sentence_scores, 'loss': loss, 'yesno_logits': _yesno_logits } return output_dict # yesno_logits = self.yesno_predictor(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1)) # # sentence_scores = masked_softmax(sentence_sim, 1 - sentence_mask, dim=-1).squeeze_(1) # output_dict = {'yesno_logits': yesno_logits, # 'sentence_logits': sentence_scores, # 'max_weight_index': sentence_scores.max(dim=1)[1], # 'max_weight': sentence_scores.max(dim=1)[0]} # loss = 0 # if answer_choice is not None: # choice_loss = F.cross_entropy(yesno_logits, answer_choice, ignore_index=-1) # loss += choice_loss # if sentence_ids is not None: # log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1), 1 - sentence_mask, dim=-1) # sentence_loss = self.evidence_lam * F.nll_loss(log_sentence_sim, sentence_ids, ignore_index=-1) # loss += sentence_loss # output_dict['loss'] = loss # return output_dict def sample_one_hot(self, _similarity, _mask): _probability = masked_softmax(_similarity, _mask) # _log_probability = masked_log_softmax(_similarity, _mask) if self.training: _distribution = Categorical(_probability) _sample_index = _distribution.sample((self.sample_steps, )) new_shape = (self.sample_steps, ) + _similarity.size() _sample_one_hot = _similarity.new_zeros(new_shape).scatter( -1, _sample_index.unsqueeze(-1), 1.0) _log_prob = _distribution.log_prob( _sample_index) # sample_steps, batch, 1 assert _log_prob.size() == new_shape[:-1], (_log_prob.size(), new_shape) _sample_one_hot = _sample_one_hot.transpose( 0, 1) # batch, sample_steps, 1, max_sen _log_prob = _log_prob.transpose(0, 1) # batch, sample_steps, 1 return _sample_one_hot, _log_prob else: _max_index = _probability.max(dim=-1, keepdim=True)[1] _one_hot = torch.zeros_like(_similarity).scatter_( -1, _max_index, 1.0) # _log_prob = _log_probability.gather(-1, _max_index) return _one_hot, None def reinforce_step(self, hidden, q_vec, label, prob, log_prob): batch, max_sen, hidden_dim = hidden.size() assert q_vec.size() == (batch, 1, hidden_dim) assert prob.size() == (batch, self.sample_steps, 1, max_sen) assert log_prob.size() == (batch, self.sample_steps, 1) expanded_hidden = hidden.unsqueeze(1).expand(-1, self.sample_steps, -1, -1) h = prob.matmul(expanded_hidden).squeeze( 2) # batch, sample_steps, hidden_dim q = q_vec.expand(-1, self.sample_steps, -1) _logits = self.yesno_predictor(torch.cat([h, q], dim=2)).view( -1, 3) # batch, sample_steps, 3 expanded_label = label.unsqueeze(1).expand( -1, self.sample_steps).reshape(-1) _loss = F.cross_entropy(_logits, expanded_label) corrects = (_logits.max(dim=-1)[1] == expanded_label).to(hidden.dtype) reward1 = (log_prob.reshape(-1) * corrects).sum() / (self.sample_steps * batch) return _loss - reward1, _logits def reinforce_step_1(self, hidden, q_vec, label, prob, log_prob): batch, max_sen, hidden_dim = hidden.size() assert q_vec.size() == (batch, 1, hidden_dim) assert prob.size() == (batch, self.sample_steps, 1, max_sen) assert log_prob.size() == (batch, self.sample_steps, 1) expanded_hidden = hidden.unsqueeze(1).expand(-1, self.sample_steps, -1, -1) h = prob.matmul(expanded_hidden).squeeze( 2) # batch, sample_steps, hidden_dim q = q_vec.expand(-1, self.sample_steps, -1) _logits = self.yesno_predictor(torch.cat([h, q], dim=2)).view( -1, 3) # batch * sample_steps, 3 expanded_label = label.unsqueeze(1).expand( -1, self.sample_steps).reshape(-1) # batch * sample_steps _loss = F.cross_entropy(_logits, expanded_label) 
_final_log_prob = F.log_softmax(_logits, dim=-1) ignore_mask = (expanded_label == -1) expanded_label = expanded_label.masked_fill(ignore_mask, 0) selected_log_prob = _final_log_prob.gather( 1, expanded_label.unsqueeze(1)).squeeze(-1) assert selected_log_prob.size() == ( batch * self.sample_steps, ), selected_log_prob.size() reward2 = -(log_prob.reshape(-1) * (selected_log_prob * (1 - ignore_mask).to(log_prob.dtype))).sum() / ( self.sample_steps * batch) return _loss - reward2, _logits def simple_step(self, hidden, q_vec, label, prob): batch, max_sen, hidden_dim = hidden.size() assert q_vec.size() == (batch, 1, hidden_dim) assert prob.size() == (batch, 1, max_sen) h = prob.bmm(hidden) _logits = self.yesno_predictor(torch.cat([h, q_vec], dim=2)).view(-1, 3) if label is not None: _loss = F.cross_entropy(_logits, label) else: _loss = _logits.new_zeros(1) return _loss, _logits
class BertForMultiLabelClassification(BertPreTrainedModel): def __init__(self, config, num_labels=20): super(BertForMultiLabelClassification, self).__init__(config) self.num_labels = num_labels self.bert = BertModel(config) self.num_capsule = 10 self.dim_capsule = 16 self.caps = Caps_Layer(batch_size=12, input_dim_capsule=config.hidden_size, num_capsule=10, dim_capsule=16, routings=5) self.dense = nn.Linear(self.num_capsule * self.dim_capsule, num_labels) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): last_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) # last_output = torch.cuda.FloatTensor(last_output) # attention_mask = torch.cuda.FloatTensor(attention_mask) pooled_output = torch.sum( last_output * attention_mask.float().unsqueeze(2), dim=1) / torch.sum(attention_mask.float(), dim=1, keepdim=True) ''' batch_size = input_ids.size(0) caps_output = self.caps(last_output) # (batch_size, num_capsule, dim_capsule) caps_output = caps_output.view(batch_size, -1) # (batch_size, num_capsule*dim_capsule) caps_dropout = self.dropout(caps_output) logits = self.dense(caps_dropout) ''' pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) if labels is not None: # loss_fct = BCEWithLogitsLoss() # loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)) alpha = 0.75 gamma = 3 # focal loss x = logits.view(-1, self.num_labels) t = labels.view(-1, self.num_labels) ''' p = x.sigmoid() pt = p*t + (1-p)*(1-t) w = alpha*t + (1-alpha)*(1-t) w = w*(1-pt).pow(gamma) # return F.binary_cross_entropy_with_logits(x, t, w, size_average=False) return binary_cross_entropy(x, t, weight=w, smooth_eps=0.1, from_logits=True) ''' loss_fct = FocalLoss(logits=True) loss = loss_fct(x, t) return loss else: return logits def freeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = True
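# `FocalLoss` is not defined in this snippet. A minimal sketch consistent with the commented-out
# math above (sigmoid focal loss applied to logits, alpha=0.75, gamma=3) follows; treat it as an
# assumption rather than the repository's actual implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=3.0, logits=True, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduction = reduction

    def forward(self, inputs, targets):
        p = torch.sigmoid(inputs) if self.logits else inputs
        pt = p * targets + (1 - p) * (1 - targets)             # probability of the true class
        w = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        w = w * (1 - pt).pow(self.gamma)                       # down-weight easy examples
        if self.logits:
            loss = F.binary_cross_entropy_with_logits(inputs, targets, weight=w, reduction="none")
        else:
            loss = F.binary_cross_entropy(inputs, targets, weight=w, reduction="none")
        return loss.mean() if self.reduction == "mean" else loss.sum()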
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT model for multi-label classification.
    This module is composed of the BERT model with a linear layer on top of the pooled output.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier. Default = 1.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token
            indices in the vocabulary (see the token preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`).
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with
            the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and
            type 1 corresponds to a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with
            indices selected in [0, 1]. It's a mask to be used if the input sequence length is
            smaller than the max input sequence length in the current batch. It's the mask that we
            typically use for attention when a batch has varying length sentences.

    Outputs:
        The classification logits of shape [batch_size, num_labels].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    num_labels = 2
    model = BertForMultiLabelSequenceClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, num_labels=1):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
class BertQAYesnoHierarchicalHardFP16(BertPreTrainedModel): """ Hard: Hard attention, using gumbel softmax of reinforcement learning. """ def __init__(self, config, evidence_lambda=0.8, use_gumbel=True, freeze_bert=False): super(BertQAYesnoHierarchicalHardFP16, self).__init__(config) logger.info(f'The model {self.__class__.__name__} is loading...') logger.info(f'The coefficient of evidence loss is {evidence_lambda}') logger.info(f'Use gumbel: {use_gumbel}') logger.info(f'If freeze BERT\'s parameters: {freeze_bert} ') layers.set_seq_dropout(True) layers.set_my_dropout_prob(config.hidden_dropout_prob) rep_layers.set_seq_dropout(True) rep_layers.set_my_dropout_prob(config.hidden_dropout_prob) self.bert = BertModel(config) if freeze_bert: for p in self.bert.parameters(): p.requires_grad = False self.doc_sen_self_attn = rep_layers.LinearSelfAttention( config.hidden_size) self.que_self_attn = rep_layers.LinearSelfAttention(config.hidden_size) self.word_similarity = layers.AttentionScore(config.hidden_size, 250, do_similarity=False) self.vector_similarity = layers.AttentionScore(config.hidden_size, 250, do_similarity=False) self.yesno_predictor = nn.Linear(config.hidden_size * 2, 3) self.evidence_lam = evidence_lambda self.use_gumbel = use_gumbel self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, answer_choice=None, sentence_span_list=None, sentence_ids=None): sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) # mask: 1 for masked value and 0 for true value # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask) doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \ rep_layers.split_doc_sen_que(sequence_output, token_type_ids, attention_mask, sentence_span_list) batch, max_sen, doc_len = doc_sen_mask.size() que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1) doc = doc_sen.reshape(batch, max_sen * doc_len, -1) # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len] word_sim = self.word_similarity(que_vec, doc).view(batch * max_sen, doc_len) doc = doc_sen.reshape(batch * max_sen, doc_len, -1) doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len) word_hidden = rep_layers.masked_softmax(word_sim, doc_mask, dim=1).unsqueeze(1).bmm(doc) word_hidden = word_hidden.view(batch, max_sen, -1) doc_vecs = self.doc_sen_self_attn(doc, doc_mask).view(batch, max_sen, -1) # [batch, 1, h] # sentence_hidden = self.vector_similarity(que_vec, doc_vecs, x2_mask=sentence_mask, x3=word_hidden).squeeze(1) # [batch, 1, max_sen] sentence_sim = self.vector_similarity(que_vec, doc_vecs) sentence_hidden = self.hard_sample( sentence_sim, use_gumbel=self.use_gumbel, dim=-1, hard=True, mask=sentence_mask).bmm(word_hidden).squeeze(1) yesno_logits = self.yesno_predictor( torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1)) sentence_scores = rep_layers.masked_softmax(sentence_sim, sentence_mask, dim=-1).squeeze_(1) output_dict = { 'yesno_logits': torch.softmax(yesno_logits, dim=-1).detach().cpu().float(), 'sentence_logits': sentence_scores } loss = 0 if answer_choice is not None: choice_loss = F.cross_entropy(yesno_logits, answer_choice, ignore_index=-1) loss += choice_loss # if sentence_ids is not None: # log_sentence_sim = rep_layers.masked_log_softmax(sentence_sim.squeeze(1), sentence_mask, dim=-1) # sentence_loss = self.evidence_lam * F.nll_loss(log_sentence_sim, sentence_ids, ignore_index=-1) # loss += sentence_loss output_dict['loss'] = loss return 
output_dict def hard_sample(self, logits, use_gumbel, dim=-1, hard=True, mask=None): if use_gumbel: if self.training: probs = rep_layers.gumbel_softmax(logits, mask=mask, hard=hard, dim=dim) return probs else: probs = rep_layers.masked_softmax(logits, mask, dim=dim) index = probs.float().max(dim, keepdim=True)[1] y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0) return y_hard else: pass
class SANBertNetwork(nn.Module):
    def __init__(self, opt, bert_config=None, use_parse=False, embedding_matrix=None, token2idx=None,
                 stx_parse_dim=None, unked_words=None, use_generic_features=False, num_generic_features=None,
                 use_domain_features=False, num_domain_features=None, feature_dim=None):
        super(SANBertNetwork, self).__init__()
        self.dropout_list = []
        self.bert_config = BertConfig.from_dict(opt)
        self.bert = BertModel(self.bert_config)
        if opt['update_bert_opt'] > 0:
            for p in self.bert.parameters():
                p.requires_grad = False
        mem_size = self.bert_config.hidden_size
        self.scoring_list = nn.ModuleList()
        labels = [int(ls) for ls in opt['label_size'].split(',')]
        task_dropout_p = opt['tasks_dropout_p']
        self.bert_pooler = None

        self.use_parse = use_parse
        self.stx_parse_dim = stx_parse_dim
        self.use_generic_features = use_generic_features
        self.use_domain_features = use_domain_features

        clf_dim = self.bert_config.hidden_size
        if self.use_parse:
            self.treelstm = BinaryTreeLSTM(self.stx_parse_dim, embedding_matrix.clone(), token2idx,
                                           unked_words=unked_words)
            parse_clf_dim = self.stx_parse_dim * 2
            clf_dim += parse_clf_dim
            self.parse_clf = nn.Linear(parse_clf_dim, labels[0])
        if self.use_generic_features:
            self.generic_feature_proj = nn.Linear(num_generic_features, num_generic_features * feature_dim)
            generic_feature_clf_dim = num_generic_features * feature_dim
            clf_dim += generic_feature_clf_dim
            self.generic_feature_clf = nn.Linear(generic_feature_clf_dim, labels[0])
        if self.use_domain_features:
            self.domain_feature_proj = nn.Linear(num_domain_features, num_domain_features * feature_dim)
            domain_feature_clf_dim = num_domain_features * feature_dim
            clf_dim += domain_feature_clf_dim
            self.domain_feature_clf = nn.Linear(domain_feature_clf_dim, labels[0])

        assert len(labels) == 1
        for task, lab in enumerate(labels):
            dropout = DropoutWrapper(task_dropout_p[task], opt['vb_dropout'])
            self.dropout_list.append(dropout)
            out_proj = nn.Linear(self.bert_config.hidden_size, lab)
            self.scoring_list.append(out_proj)

        self.opt = opt
        self._my_init()
        self.set_embed(opt)
        if embedding_matrix is not None and self.use_parse:
            # Set again because self._my_init() overwrites it.
            self.treelstm.embedding.weight = nn.Parameter(embedding_matrix)

    def _my_init(self):
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version, which uses truncated_normal for initialization.
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0,
                                           std=self.bert_config.initializer_range * self.opt['init_ratio'])
            elif isinstance(module, BertLayerNorm):
                # Slightly different from the BERT PyTorch version, which should be a bug.
                # Note that it only affects training from scratch. For detailed discussions, please contact xiaodl@.
                # Layer normalization (https://arxiv.org/abs/1607.06450): support both old/latest versions.
                if 'beta' in dir(module) and 'gamma' in dir(module):
                    module.beta.data.zero_()
                    module.gamma.data.fill_(1.0)
                else:
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear):
                module.bias.data.zero_()

        self.apply(init_weights)

    def nbert_layer(self):
        return len(self.bert.encoder.layer)

    def freeze_layers(self, max_n):
        assert max_n < self.nbert_layer()
        for i in range(0, max_n):
            self.freeze_layer(i)

    def freeze_layer(self, n):
        assert n < self.nbert_layer()
        layer = self.bert.encoder.layer[n]
        for p in layer.parameters():
            p.requires_grad = False

    def set_embed(self, opt):
        bert_embeddings = self.bert.embeddings
        emb_opt = opt['embedding_opt']
        if emb_opt == 1:
            for p in bert_embeddings.word_embeddings.parameters():
                p.requires_grad = False
        elif emb_opt == 2:
            for p in bert_embeddings.position_embeddings.parameters():
                p.requires_grad = False
        elif emb_opt == 3:
            for p in bert_embeddings.token_type_embeddings.parameters():
                p.requires_grad = False
        elif emb_opt == 4:
            for p in bert_embeddings.token_type_embeddings.parameters():
                p.requires_grad = False
            for p in bert_embeddings.position_embeddings.parameters():
                p.requires_grad = False

    def forward(self, input_ids, token_type_ids, attention_mask, premise_mask=None, hyp_mask=None, task_id=0,
                bin_parse_as=None, bin_parse_bs=None, parse_as_mask=None, parse_bs_mask=None,
                generic_features=None, domain_features=None):
        all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        sequence_output = all_encoder_layers[-1]
        if self.bert_pooler is not None:
            pooled_output = self.bert_pooler(sequence_output)
        pooled_output = self.dropout_list[task_id](pooled_output)
        logits = self.scoring_list[task_id](pooled_output)

        if self.use_parse:
            parse_embeddings = torch.FloatTensor(len(input_ids), self.stx_parse_dim * 2).to(input_ids.device)
            assert len(bin_parse_as) == len(bin_parse_bs) == len(parse_as_mask) == len(parse_bs_mask)
            for i, (parse_a, parse_b, parse_a_mask, parse_b_mask) in enumerate(
                    zip(bin_parse_as, bin_parse_bs, parse_as_mask, parse_bs_mask)):
                parse_a = parse_a[:parse_a_mask.sum()]
                parse_b = parse_b[:parse_b_mask.sum()]
                t = Tree.from_char_indices(parse_a)
                parse_embeddings[i, :self.stx_parse_dim] = self.treelstm(t)[1]
                t = Tree.from_char_indices(parse_b)
                parse_embeddings[i, self.stx_parse_dim:] = self.treelstm(t)[1]
            logits += self.parse_clf(self.dropout_list[task_id](parse_embeddings))
        if self.use_generic_features:
            # features: bsz * n_features
            generic_feature_embeddings = F.relu(self.generic_feature_proj(generic_features))
            logits += self.generic_feature_clf(self.dropout_list[task_id](generic_feature_embeddings))
        if self.use_domain_features:
            # features: bsz * n_features
            domain_feature_embeddings = F.relu(self.domain_feature_proj(domain_features))
            logits += self.domain_feature_clf(self.dropout_list[task_id](domain_feature_embeddings))
        return logits
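# The forward pass above fuses the extra views (parse trees, generic/domain features) by *adding*
# their class logits to the BERT logits (late fusion), rather than concatenating hidden states.
# A minimal, self-contained sketch of that pattern with toy dimensions (not the classes used above):
import torch
import torch.nn as nn
import torch.nn.functional as F


class LateFusionClassifierSketch(nn.Module):
    def __init__(self, hidden_size=768, num_features=10, feature_dim=8, num_labels=3):
        super().__init__()
        self.text_clf = nn.Linear(hidden_size, num_labels)
        self.feature_proj = nn.Linear(num_features, num_features * feature_dim)
        self.feature_clf = nn.Linear(num_features * feature_dim, num_labels)

    def forward(self, pooled_output, features):
        # Base logits from the BERT pooled output.
        logits = self.text_clf(pooled_output)
        # Each extra view contributes its own logits; summing acts like an ensemble of linear heads.
        logits = logits + self.feature_clf(F.relu(self.feature_proj(features)))
        return logits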
class SANBertNetwork(nn.Module):
    def __init__(self, opt, bert_config=None):
        super(SANBertNetwork, self).__init__()
        self.dropout_list = nn.ModuleList()
        self.encoder_type = opt['encoder_type']
        if opt['encoder_type'] == EncoderModelType.ROBERTA:
            from fairseq.models.roberta import RobertaModel
            self.bert = RobertaModel.from_pretrained(opt['init_checkpoint'])
            hidden_size = self.bert.args.encoder_embed_dim
            self.pooler = LinearPooler(hidden_size)
        else:
            self.bert_config = BertConfig.from_dict(opt)
            self.bert = BertModel(self.bert_config)
            hidden_size = self.bert_config.hidden_size

        if opt.get('dump_feature', False):
            self.opt = opt
            return
        if opt['update_bert_opt'] > 0:
            for p in self.bert.parameters():
                p.requires_grad = False

        self.decoder_opt = opt['answer_opt']
        self.task_types = opt["task_types"]
        self.scoring_list = nn.ModuleList()
        labels = [int(ls) for ls in opt['label_size'].split(',')]
        task_dropout_p = opt['tasks_dropout_p']

        for task, lab in enumerate(labels):
            decoder_opt = self.decoder_opt[task]
            task_type = self.task_types[task]
            dropout = DropoutWrapper(task_dropout_p[task], opt['vb_dropout'])
            self.dropout_list.append(dropout)
            if task_type == TaskType.Span:
                assert decoder_opt != 1
                out_proj = nn.Linear(hidden_size, 2)
            elif task_type == TaskType.SeqenceLabeling:
                out_proj = nn.Linear(hidden_size, lab)
            else:
                if decoder_opt == 1:
                    out_proj = SANClassifier(hidden_size, hidden_size, lab, opt, prefix='answer', dropout=dropout)
                else:
                    out_proj = nn.Linear(hidden_size, lab)
            self.scoring_list.append(out_proj)

        self.opt = opt
        self._my_init()

    def _my_init(self):
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version, which uses truncated_normal for initialization.
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=0.02 * self.opt['init_ratio'])
            elif isinstance(module, BertLayerNorm):
                # Slightly different from the BERT PyTorch version, which should be a bug.
                # Note that it only affects training from scratch. For detailed discussions, please contact xiaodl@.
                # Layer normalization (https://arxiv.org/abs/1607.06450): support both old/latest versions.
                if 'beta' in dir(module) and 'gamma' in dir(module):
                    module.beta.data.zero_()
                    module.gamma.data.fill_(1.0)
                else:
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear):
                module.bias.data.zero_()

        self.apply(init_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, premise_mask=None, hyp_mask=None, task_id=0):
        # Older data pipelines may still pass uint8 masks; newer PyTorch expects bool masks.
        if attention_mask is not None and attention_mask.dtype == torch.uint8:
            attention_mask = attention_mask.bool()
        if premise_mask is not None and premise_mask.dtype == torch.uint8:
            premise_mask = premise_mask.bool()
        if hyp_mask is not None and hyp_mask.dtype == torch.uint8:
            hyp_mask = hyp_mask.bool()

        if self.encoder_type == EncoderModelType.ROBERTA:
            sequence_output = self.bert.extract_features(input_ids)
            pooled_output = self.pooler(sequence_output)
        else:
            all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
            sequence_output = all_encoder_layers[-1]

        decoder_opt = self.decoder_opt[task_id]
        task_type = self.task_types[task_id]
        if task_type == TaskType.Span:
            assert decoder_opt != 1
            sequence_output = self.dropout_list[task_id](sequence_output)
            logits = self.scoring_list[task_id](sequence_output)
            start_scores, end_scores = logits.split(1, dim=-1)
            start_scores = start_scores.squeeze(-1)
            end_scores = end_scores.squeeze(-1)
            return start_scores, end_scores
        elif task_type == TaskType.SeqenceLabeling:
            pooled_output = all_encoder_layers[-1]
            pooled_output = self.dropout_list[task_id](pooled_output)
            pooled_output = pooled_output.contiguous().view(-1, pooled_output.size(2))
            logits = self.scoring_list[task_id](pooled_output)
            return logits
        else:
            if decoder_opt == 1:
                max_query = hyp_mask.size(1)
                assert max_query > 0
                assert premise_mask is not None
                assert hyp_mask is not None
                hyp_mem = sequence_output[:, :max_query, :]
                logits = self.scoring_list[task_id](sequence_output, hyp_mem, premise_mask, hyp_mask)
            else:
                pooled_output = self.dropout_list[task_id](pooled_output)
                logits = self.scoring_list[task_id](pooled_output)
            return logits
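# `LinearPooler` above is defined elsewhere in the project and is not shown here. As a hedged sketch,
# a BERT-style pooler typically applies a dense layer plus tanh to the first ([CLS]/<s>) token of the
# sequence output; the real class may differ in details:
import torch
import torch.nn as nn


class LinearPoolerSketch(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # hidden_states: [batch, seq_len, hidden]; pool the first token's representation.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))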
class BertQAYesnoHierarchicalSingle(BertPreTrainedModel):
    """
    BertForQuestionAnsweringForYesNo

    Hierarchical attention:
        - Use a hierarchical attention module to predict Non/Yes/No.
        - Add supervision to the sentence attention.
    Sentence-level model.
    """

    def __init__(self, config, evidence_lambda=0.8, negative_lambda=1.0, add_entropy: bool = False,
                 fix_bert: bool = False):
        super(BertQAYesnoHierarchicalSingle, self).__init__(config)
        logger.info(f'The model {self.__class__.__name__} is loading...')
        logger.info(f'The coefficient of evidence loss is {evidence_lambda}')
        logger.info(f'The coefficient of negative samples loss is {negative_lambda}')
        logger.info(f'Fix parameters of BERT: {fix_bert}')
        logger.info(f'Add entropy loss: {add_entropy}')
        # logger.info(f'Use bidirectional attention before summarizing vectors: {bi_attention}')
        layers.set_seq_dropout(True)
        layers.set_my_dropout_prob(config.hidden_dropout_prob)
        self.bert = BertModel(config)
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # self.answer_choice = nn.Linear(config.hidden_size, 2)
        if fix_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        self.doc_sen_self_attn = layers.LinearSelfAttnAllennlp(config.hidden_size)
        self.que_self_attn = layers.LinearSelfAttn(config.hidden_size)
        self.word_similarity = layers.AttentionScore(config.hidden_size, 250, do_similarity=False)
        self.vector_similarity = layers.AttentionScore(config.hidden_size, 250, do_similarity=False)
        # self.doc_sen_encoder = layers.StackedBRNN(config.hidden_size, 125, num_layers=1)
        # self.yesno_predictor = nn.Linear(config.hidden_size, 2)
        self.yesno_predictor = nn.Linear(config.hidden_size * 2, 3)
        self.evidence_lam = evidence_lambda
        self.negative_lam = negative_lambda
        self.add_entropy = add_entropy
        self.apply(self.init_bert_weights)

    def forward(self, ques_input_ids, ques_input_mask, pass_input_ids, pass_input_mask,
                answer_choice=None, sentence_ids=None, sentence_label=None):
        # Encoding question
        q_len = ques_input_ids.size(1)
        question, _ = self.bert(ques_input_ids, token_type_ids=None, attention_mask=ques_input_mask,
                                output_all_encoded_layers=False)
        # Encoding passage
        batch, max_sen_num, p_len = pass_input_ids.size()
        pass_input_ids = pass_input_ids.reshape(batch * max_sen_num, p_len)
        pass_input_mask = pass_input_mask.reshape(batch * max_sen_num, p_len)
        passage, _ = self.bert(pass_input_ids, token_type_ids=None, attention_mask=pass_input_mask,
                               output_all_encoded_layers=False)

        que_mask = (1 - ques_input_mask).byte()
        que_vec = layers.weighted_avg(question, self.que_self_attn(question, que_mask)).view(batch, 1, -1)

        doc = passage.reshape(batch, max_sen_num * p_len, -1)
        # [batch, max_sen, doc_len] -> [batch * max_sen, doc_len]
        word_sim = self.word_similarity(que_vec, doc).view(batch * max_sen_num, p_len)
        # doc_mask = 1 - pass_input_mask
        # 1 for true value and 0 for mask (the original read `pass_input_ids` here, which looks like a typo
        # given the surrounding comments all describe the 0/1 attention mask).
        doc_mask = pass_input_mask
        # [batch * max_sen, doc_len] -> [batch * max_sen, 1, doc_len] -> [batch * max_sen, 1, h]
        word_hidden = masked_softmax(word_sim, doc_mask, dim=1).unsqueeze(1).bmm(passage)
        word_hidden = word_hidden.view(batch, max_sen_num, -1)

        sentence_mask = pass_input_mask.reshape(batch, max_sen_num, p_len).sum(dim=-1).ge(1.0).float()

        # 1 - doc_mask: 0 for true value and 1 for mask
        doc_vecs = layers.weighted_avg(passage, self.doc_sen_self_attn(passage, 1 - doc_mask)).view(
            batch, max_sen_num, -1)

        # [batch, 1, max_sen]
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        # sentence_scores = masked_softmax(sentence_sim, 1 - sentence_mask)
        sentence_scores = masked_softmax(sentence_sim, sentence_mask)  # 1 for true value and 0 for mask
        sentence_hidden = sentence_scores.bmm(word_hidden).squeeze(1)
        yesno_logits = self.yesno_predictor(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1))

        sentence_scores = sentence_scores.squeeze(1)
        max_sentence_score = sentence_scores.max(dim=-1)
        output_dict = {
            'yesno_logits': yesno_logits,
            'sentence_logits': sentence_scores,
            'max_weight': max_sentence_score[0],
            'max_weight_index': max_sentence_score[1]
        }

        loss = 0
        if answer_choice is not None:
            choice_loss = F.cross_entropy(yesno_logits, answer_choice, ignore_index=-1)
            loss += choice_loss
        if sentence_ids is not None:
            log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1), sentence_mask, dim=-1)
            sentence_loss = self.evidence_lam * F.nll_loss(log_sentence_sim, sentence_ids, ignore_index=-1)
            loss += sentence_loss
            if self.add_entropy:
                no_evidence_mask = (sentence_ids != -1)
                entropy = layers.get_masked_entropy(sentence_scores, mask=no_evidence_mask)
                loss += self.evidence_lam * entropy
        if sentence_label is not None:
            # sentence_label: batch * List[k]
            # [batch, max_sen]
            # log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1), 1 - sentence_mask, dim=-1)
            sentence_prob = 1 - sentence_scores
            log_sentence_sim = -torch.log(sentence_prob + 1e-15)
            negative_loss = 0
            for b in range(batch):
                for sen_id, k in enumerate(sentence_label[b]):
                    negative_loss += k * log_sentence_sim[b][sen_id]
            negative_loss /= batch
            loss += self.negative_lam * negative_loss
        output_dict['loss'] = loss
        return output_dict
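# `masked_softmax` / `masked_log_softmax` above come from the project's `layers` module. A hedged,
# self-contained sketch of the usual semantics (mask broadcastable to the scores, 1 for real positions,
# 0 for padding); the project's versions may differ in edge-case handling such as fully-masked rows:
import torch
import torch.nn.functional as F


def masked_softmax_sketch(scores, mask, dim=-1):
    # Push padded positions to -inf so they receive (numerically) zero probability after softmax.
    scores = scores.masked_fill(mask == 0, float('-inf'))
    return F.softmax(scores, dim=dim)


def masked_log_softmax_sketch(scores, mask, dim=-1):
    scores = scores.masked_fill(mask == 0, float('-inf'))
    return F.log_softmax(scores, dim=dim)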