# Imports assumed from the surrounding project: BertModel / BertConfig follow the
# pytorch-pretrained-bert style interface (the `output_all_encoded_layers` keyword),
# and SingleNonLinearClassifier / MultiNonLinearClassifier are project-local
# classifier layers.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


class BertQueryNER(nn.Module):
    def __init__(self, config):
        super(BertQueryNER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)
        self.start_outputs = nn.Linear(config.hidden_size, 2)
        self.end_outputs = nn.Linear(config.hidden_size, 2)
        # self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2, 1, config.dropout)
        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)
        self.loss_wb = config.weight_start
        self.loss_we = config.weight_end
        self.loss_ws = config.weight_span

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                start_positions=None, end_positions=None, span_positions=None):
        """
        Args:
            start_positions: (batch x max_len x 1)
                [[0, 1, 0, 0, 1, 0, 1, 0, 0, ], [0, 1, 0, 0, 1, 0, 1, 0, 0, ]]
            end_positions: (batch x max_len x 1)
                [[0, 1, 0, 0, 1, 0, 1, 0, 0, ], [0, 1, 0, 0, 1, 0, 1, 0, 0, ]]
            span_positions: (batch x max_len x max_len)
                span_positions[k][i][j] is one of [0, 1]; it represents whether the span
                from start_pos i to end_pos j of the k-th sentence in the batch is an entity.
        """
        sequence_output, pooled_output, _ = self.bert(
            input_ids, token_type_ids, attention_mask,
            output_all_encoded_layers=False)

        sequence_heatmap = sequence_output  # batch x seq_len x hidden
        batch_size, seq_len, hid_size = sequence_heatmap.size()

        start_logits = self.start_outputs(sequence_heatmap)  # batch x seq_len x 2
        end_logits = self.end_outputs(sequence_heatmap)  # batch x seq_len x 2

        # For every position i in the sequence, concatenate every position j to
        # predict whether i and j are the start_pos and end_pos of an entity.
        # start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
        # end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # the shape of start_end_concat[0] is: batch x 1 x seq_len x 2*hidden
        # span_matrix = torch.cat([start_extend, end_extend], 3)  # batch x seq_len x seq_len x 2*hidden
        # span_logits = self.span_embedding(span_matrix)  # batch x seq_len x seq_len x 1
        # span_logits = torch.squeeze(span_logits)  # batch x seq_len x seq_len

        if start_positions is not None and end_positions is not None:
            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits.view(-1, 2), start_positions.view(-1))
            end_loss = loss_fct(end_logits.view(-1, 2), end_positions.view(-1))
            # span_loss_fct = nn.BCEWithLogitsLoss()
            # span_loss = span_loss_fct(span_logits.view(batch_size, -1), span_positions.view(batch_size, -1).float())
            # total_loss = self.loss_wb * start_loss + self.loss_we * end_loss + self.loss_ws * span_loss
            total_loss = self.loss_wb * start_loss + self.loss_we * end_loss
            return total_loss
        else:
            # span_logits = torch.sigmoid(span_logits)  # batch x seq_len x seq_len
            start_logits = torch.argmax(start_logits, dim=-1)
            end_logits = torch.argmax(end_logits, dim=-1)
            # return start_logits, end_logits, span_logits
            return start_logits, end_logits
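# Illustrative sketch (not part of the original model code): how the 0/1 start/end
# labels in the forward() docstring relate to the logits produced by the two heads.
# The tensors below are random stand-ins for the model outputs; shapes follow the
# "batch x seq_len x 2" comments above.
def _demo_start_end_decoding():
    batch_size, seq_len = 2, 8
    start_logits = torch.randn(batch_size, seq_len, 2)  # what self.start_outputs returns
    end_logits = torch.randn(batch_size, seq_len, 2)    # what self.end_outputs returns
    # argmax over the last dim gives a 0/1 tag per token, matching the label format
    # shown in the docstring ([[0, 1, 0, ...], ...]).
    start_preds = torch.argmax(start_logits, dim=-1)  # batch x seq_len
    end_preds = torch.argmax(end_logits, dim=-1)      # batch x seq_len
    return start_preds, end_preds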
class BertMRCNER(nn.Module):
    """
    Desc:
        BERT model for question answering (span extraction).
        This module is composed of the BERT model with a linear layer on top of the
        sequence output that computes start_logits and end_logits.
    Params:
        config: a BertConfig class instance with the configuration to build a new model.
    Inputs:
        input_ids: torch.LongTensor of shape [batch_size, sequence_length].
        token_type_ids: an optional torch.LongTensor, [batch_size, sequence_length],
            with token types in [0, 1]. Type 0 corresponds to sentence A, type 1 to sentence B.
        attention_mask: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with values in [0, 1]. It is a mask to be used if the input sequence length is
            smaller than the max input sequence length in the current batch.
        start_positions: positions of the first token of each labeled span.
            torch.LongTensor of shape [batch_size, seq_len]; the value is 1 if the current
            position is the start of an entity, else 0.
        end_positions: positions of the last token of each labeled span.
            torch.LongTensor of shape [batch_size, seq_len].
    Outputs:
        If "start_positions" and "end_positions" are not None, output the total_loss, the
        average of the CrossEntropy losses for the start and end token positions.
        If "start_positions" or "end_positions" is None, output the start_logits and end_logits.
    """

    def __init__(self, config):
        super(BertMRCNER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)
        self.start_outputs = nn.Linear(config.hidden_size, 2)
        self.end_outputs = nn.Linear(config.hidden_size, 2)
        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)
        self.cluster_layer = config.cluster_layer

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                start_positions=None, end_positions=None):
        sequence_output, _, _, _ = self.bert(input_ids, token_type_ids, attention_mask,
                                             output_all_encoded_layers=False)
        sequence_output = sequence_output.view(-1, self.hidden_size)
        start_logits = self.start_outputs(sequence_output)
        end_logits = self.end_outputs(sequence_output)

        if start_positions is not None and end_positions is not None:
            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits.view(-1, 2), start_positions.view(-1))
            end_loss = loss_fct(end_logits.view(-1, 2), end_positions.view(-1))
            # total_loss = start_loss + end_loss + span_loss
            total_loss = (start_loss + end_loss) / 2
            return total_loss
        else:
            return start_logits, end_logits
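# Illustrative sketch (random tensors stand in for BERT outputs; not the original
# training loop): the loss that BertMRCNER.forward computes is an ordinary 2-class
# cross-entropy over every token position for the start head and the end head,
# averaged together exactly as in the forward() above.
def _demo_mrc_ner_loss():
    batch_size, seq_len, hidden = 2, 8, 16
    sequence_output = torch.randn(batch_size * seq_len, hidden)  # stand-in for BERT features
    start_head = nn.Linear(hidden, 2)
    end_head = nn.Linear(hidden, 2)
    start_logits = start_head(sequence_output)  # (batch*seq_len, 2)
    end_logits = end_head(sequence_output)      # (batch*seq_len, 2)
    start_positions = torch.randint(0, 2, (batch_size, seq_len))  # 1 = token starts an entity
    end_positions = torch.randint(0, 2, (batch_size, seq_len))    # 1 = token ends an entity
    loss_fct = CrossEntropyLoss()
    start_loss = loss_fct(start_logits.view(-1, 2), start_positions.view(-1))
    end_loss = loss_fct(end_logits.view(-1, 2), end_positions.view(-1))
    return (start_loss + end_loss) / 2  # same averaging as BertMRCNER.forward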
class BertMRCNER_CLUSTER(nn.Module):
    """
    Desc:
        BERT model for question answering (span extraction).
        This module is composed of the BERT model with a linear layer on top of the
        sequence output that computes start_logits and end_logits.
    Params:
        config: a BertConfig class instance with the configuration to build a new model.
    Inputs:
        input_ids: torch.LongTensor of shape [batch_size, sequence_length].
        token_type_ids: an optional torch.LongTensor, [batch_size, sequence_length],
            with token types in [0, 1]. Type 0 corresponds to sentence A, type 1 to sentence B.
        attention_mask: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with values in [0, 1]. It is a mask to be used if the input sequence length is
            smaller than the max input sequence length in the current batch.
        start_positions: positions of the first token of each labeled span.
            torch.LongTensor of shape [batch_size, seq_len]; the value is 1 if the current
            position is the start of an entity, else 0.
        end_positions: positions of the last token of each labeled span.
            torch.LongTensor of shape [batch_size, seq_len].
    Outputs:
        If "start_positions" and "end_positions" are not None, output the total_loss, the
        average of the CrossEntropy losses for the start and end token positions, plus the
        weighted cluster classification loss when "input_truth" is given.
        If "start_positions" or "end_positions" is None, output the start_logits, end_logits
        and a placeholder span_logits tensor.
    """

    def __init__(self, config):
        super(BertMRCNER_CLUSTER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)
        self.start_outputs = nn.Linear(config.hidden_size, 2)
        self.end_outputs = nn.Linear(config.hidden_size, 2)
        self.cluster_classify = nn.Linear(config.hidden_size, config.num_clusters)
        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)
        self.margin = config.margin
        self.gama = config.gama
        self.cluster_layer = config.cluster_layer
        self.pool_mode = config.pool_mode
        self.drop = nn.Dropout(config.dropout_rate)

    def KLloss(self, probs1, probs2):
        loss = nn.KLDivLoss()
        log_probs1 = F.log_softmax(probs1, 1)
        probs2 = F.softmax(probs2, 1)
        return loss(log_probs1, probs2)

    def get_features(self, input_ids, token_type_ids=None, attention_mask=None,
                     start_positions=None, end_positions=None):
        sequence_output, _, _ = self.bert(input_ids, token_type_ids, attention_mask,
                                          output_all_encoded_layers=False)
        sequence_output = sequence_output.view(-1, self.hidden_size)
        start_positions = start_positions.view(-1)
        end_positions = end_positions.view(-1)
        start_pos = np.argwhere(start_positions.cpu().numpy() == 1)
        end_pos = np.argwhere(end_positions.cpu().numpy() == 1)
        start_pos = np.reshape(start_pos, (len(start_pos))).tolist()
        end_pos = np.reshape(end_pos, (len(end_pos))).tolist()

        features = []
        for i, s in enumerate(start_pos):
            if i >= len(end_pos):
                continue
            e = end_pos[i]
            if len(features) == 0:
                features = sequence_output[s:e + 1]
                if self.pool_mode == "sum":
                    features = torch.sum(features, dim=0, keepdim=True)
                elif self.pool_mode == "avg":
                    features = torch.mean(features, dim=0, keepdim=True)
                elif self.pool_mode == "max":
                    features = features.transpose(0, 1).unsqueeze(0)
                    features = F.max_pool1d(input=features,
                                            kernel_size=features.size(2)).transpose(1, 2).squeeze(0)
            else:
                aux = sequence_output[s:e + 1]
                if self.pool_mode == "sum":
                    aux = torch.sum(aux, dim=0, keepdim=True)
                elif self.pool_mode == "avg":
                    aux = torch.mean(aux, dim=0, keepdim=True)
                elif self.pool_mode == "max":
                    aux = aux.transpose(0, 1).unsqueeze(0)
                    aux = F.max_pool1d(input=aux,
                                       kernel_size=aux.size(2)).transpose(1, 2).squeeze(0)
                features = torch.cat((features, aux), 0)
        # features = self.cluster_outputs(features)
        return features

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                start_positions=None, end_positions=None, span_positions=None,
                input_truth=None, cluster_var=None):
        sequence_output, _, _ = self.bert(input_ids, token_type_ids, attention_mask,
                                          output_all_encoded_layers=False)
        # sequence_output = self.dropout(sequence_output.view(-1, self.hidden_size))
        start_logits = self.start_outputs(sequence_output)
        end_logits = self.end_outputs(sequence_output)
        sequence_output = sequence_output.view(-1, self.hidden_size)

        if start_positions is not None and end_positions is not None:
            loss_fct = CrossEntropyLoss()
            start_positions = start_positions.view(-1).long()
            end_positions = end_positions.view(-1).long()
            # NER loss
            start_loss = loss_fct(start_logits.view(-1, 2), start_positions)
            end_loss = loss_fct(end_logits.view(-1, 2), end_positions)
            # total_loss = start_loss + end_loss + span_loss
            total_loss = (start_loss + end_loss) / 2

            if input_truth is not None:
                # cluster loss
                loss_fct_cluster = CrossEntropyLoss(cluster_var)
                start_pos = np.argwhere(start_positions.cpu().numpy() == 1)
                end_pos = np.argwhere(end_positions.cpu().numpy() == 1)
                start_pos = np.reshape(start_pos, (len(start_pos))).tolist()
                end_pos = np.reshape(end_pos, (len(end_pos))).tolist()

                features = []
                for i, s in enumerate(start_pos):
                    if i >= len(end_pos):
                        continue
                    e = end_pos[i]
                    if i == 0:
                        features = sequence_output[s:e + 1]
                        if self.pool_mode == "sum":
                            features = torch.sum(features, dim=0, keepdim=True)
                        elif self.pool_mode == "avg":
                            features = torch.mean(features, dim=0, keepdim=True)
                        elif self.pool_mode == "max":
                            features = features.transpose(0, 1).unsqueeze(0)
                            features = F.max_pool1d(input=features,
                                                    kernel_size=features.size(2)).transpose(1, 2).squeeze(0)
                    else:
                        aux = sequence_output[s:e + 1]
                        if self.pool_mode == "sum":
                            aux = torch.sum(aux, dim=0, keepdim=True)
                        elif self.pool_mode == "avg":
                            aux = torch.mean(aux, dim=0, keepdim=True)
                        elif self.pool_mode == "max":
                            aux = aux.transpose(0, 1).unsqueeze(0)
                            aux = F.max_pool1d(input=aux,
                                               kernel_size=aux.size(2)).transpose(1, 2).squeeze(0)
                        features = torch.cat((features, aux), 0)

                if len(features) == 0:
                    return total_loss
                features = self.drop(features)
                prob = self.cluster_classify(features)
                CEloss1 = loss_fct_cluster(prob, input_truth[:len(prob)])
                # CEloss2 = loss_fct(prob_C, input_truth[:len(prob_C)])
                # KL = self.KLloss(prob, prob_C)
                # cluster_loss = CEloss1 + CEloss2 + KL
                # cluster_loss = loss_fct_cluster(cluster, input_truth[:len(cluster)])
                # print("total_loss: ", total_loss)
                # print("cluster_loss: ", cluster_loss)
                return total_loss + self.gama * CEloss1
            else:
                return total_loss
        else:
            span_logits = torch.ones(start_logits.size(0), start_logits.size(1),
                                     start_logits.size(1)).cuda()
            return start_logits, end_logits, span_logits
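# Illustrative sketch (toy tensors, not the original code): what the sum/avg/max
# pooling branches in get_features() and forward() compute for a single entity span.
# Each span's token vectors sequence_output[s:e + 1] are collapsed into one
# (1, hidden) feature vector that is later fed to self.cluster_classify.
def _demo_span_pooling(pool_mode="max"):
    span_len, hidden = 3, 16
    span = torch.randn(span_len, hidden)  # stand-in for sequence_output[s:e + 1]
    if pool_mode == "sum":
        feat = torch.sum(span, dim=0, keepdim=True)
    elif pool_mode == "avg":
        feat = torch.mean(span, dim=0, keepdim=True)
    elif pool_mode == "max":
        # transpose to (1, hidden, span_len) so max_pool1d pools over the span axis,
        # then reshape back; equivalent to span.max(dim=0, keepdim=True).values
        feat = span.transpose(0, 1).unsqueeze(0)
        feat = F.max_pool1d(input=feat, kernel_size=feat.size(2)).transpose(1, 2).squeeze(0)
    return feat  # shape: (1, hidden)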
class BertQueryNER(nn.Module):
    def __init__(self, config, train_steps=1200000):
        super(BertQueryNER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)

        self.start_outputs = SingleNonLinearClassifier(config.hidden_size, 2, config.dropout)
        self.end_outputs = SingleNonLinearClassifier(config.hidden_size, 2, config.dropout)
        self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2, 1, config.dropout)
        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)
        self.train_steps = train_steps
        self.loss_wb = config.weight_start
        self.loss_we = config.weight_end
        self.loss_ws = config.weight_span
        self.device = torch.device("cuda")
        self.loss_type = config.loss_type

        if "dynamic_wce" in self.loss_type:
            start_sig = torch.empty(1)
            end_sig = torch.empty(1)
            span_sig = torch.empty(1)
            # test different init scale
            self._start_loss_sig = nn.init.normal_(start_sig).to(self.device)
            self._end_loss_sig = nn.init.normal_(end_sig).to(self.device)
            self._span_loss_sig = nn.init.normal_(span_sig).to(self.device)

    def update_loss_ratio(self, current_train_step=None, decay_step=5000,
                          lower_bound_weight=0.6, upper_bound_weight=1.5,
                          decay_base=3.0, increase_base=1.5):
        if current_train_step is None:
            return
        if current_train_step > decay_step:
            loss_wb = self.loss_wb * (decay_base ** -(current_train_step / self.train_steps))
            loss_we = self.loss_we * (decay_base ** -(current_train_step / self.train_steps))
            self.loss_wb = loss_wb if loss_wb > lower_bound_weight else lower_bound_weight
            self.loss_we = loss_we if loss_we > lower_bound_weight else lower_bound_weight

            loss_ws = self.loss_ws * (increase_base ** (current_train_step / self.train_steps))
            self.loss_ws = loss_ws if loss_ws <= upper_bound_weight else upper_bound_weight

            if current_train_step % 1000 == 0:
                print(f"*** *** *** >>> update loss weight: {self.loss_wb}, {self.loss_we}, {self.loss_ws}")

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                start_positions=None, end_positions=None, span_positions=None,
                span_label_mask=None, current_step=None):
        """
        Args:
            start_positions: (batch x max_len x 1)
                [[0, 1, 0, 0, 1, 0, 1, 0, 0, ], [0, 1, 0, 0, 1, 0, 1, 0, 0, ]]
            end_positions: (batch x max_len x 1)
                [[0, 1, 0, 0, 1, 0, 1, 0, 0, ], [0, 1, 0, 0, 1, 0, 1, 0, 0, ]]
            span_positions: (batch x max_len x max_len)
                span_positions[k][i][j] is one of [0, 1]; it represents whether the span
                from start_pos i to end_pos j of the k-th sentence in the batch is an entity.
        """
        sequence_output, pooled_output, _ = self.bert(
            input_ids, token_type_ids, attention_mask,
            output_all_encoded_layers=False)

        sequence_heatmap = sequence_output  # batch x seq_len x hidden
        batch_size, seq_len, hid_size = sequence_heatmap.size()

        start_logits = self.start_outputs(sequence_heatmap)  # batch x seq_len x 2
        end_logits = self.end_outputs(sequence_heatmap)  # batch x seq_len x 2

        # For every position i in the sequence, concatenate every position j to
        # predict whether i and j are the start_pos and end_pos of an entity.
        start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
        end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # the shape of start_end_concat[0] is: batch x 1 x seq_len x 2*hidden
        span_matrix = torch.cat([start_extend, end_extend], 3)  # batch x seq_len x seq_len x 2*hidden
        span_logits = self.span_embedding(span_matrix)  # batch x seq_len x seq_len x 1
        span_logits = torch.squeeze(span_logits)  # batch x seq_len x seq_len

        if start_positions is not None and end_positions is not None:
            # self.update_loss_ratio(current_train_step=current_step)
            valid_num = torch.sum(token_type_ids)
            loss_fct = nn.CrossEntropyLoss(reduction="none")

            start_loss = loss_fct(start_logits.view(-1, 2), start_positions.view(-1))
            start_loss = torch.sum(start_loss * token_type_ids.view(-1))
            start_loss = start_loss / valid_num.float()

            end_loss = loss_fct(end_logits.view(-1, 2), end_positions.view(-1))
            end_loss = torch.sum(end_loss * token_type_ids.view(-1))
            end_loss = end_loss / valid_num.float()

            span_loss_fct = nn.BCEWithLogitsLoss(reduction="none")
            span_loss = span_loss_fct(span_logits.view(batch_size, -1),
                                      span_positions.view(batch_size, -1).float())
            valid_span_num = torch.sum(span_label_mask)
            span_loss = torch.sum(span_loss.view(-1) * span_label_mask.view(-1))
            span_loss = span_loss / valid_span_num.float()

            total_loss = self._compute_loss(start_loss, end_loss, span_loss,
                                            loss_type=self.loss_type)
            # total_loss = self.loss_wb * start_loss + self.loss_we * end_loss + self.loss_ws * span_loss
            return total_loss
        else:
            span_scores = torch.sigmoid(span_logits)  # batch x seq_len x seq_len
            start_labels = torch.argmax(start_logits, dim=-1)
            end_labels = torch.argmax(end_logits, dim=-1)
            return start_labels, end_labels, span_scores

    def _compute_loss(self, start_loss, end_loss, span_loss, loss_type="ce"):
        if loss_type == "ce":
            total_loss = self.loss_wb * start_loss + self.loss_we * end_loss + self.loss_ws * span_loss
            return total_loss
        elif loss_type == "dynamic_wce":
            b_factor = torch.exp(-self._start_loss_sig)
            b_loss = b_factor * start_loss + self._start_loss_sig
            e_factor = torch.exp(-self._end_loss_sig)
            e_loss = e_factor * end_loss + self._end_loss_sig
            s_factor = torch.exp(-self._span_loss_sig)
            s_loss = s_factor * span_loss + self._span_loss_sig
            total_loss = b_loss + e_loss + s_loss
            return total_loss
        else:
            raise ValueError("Loss type does not exist.")
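# Illustrative sketch (scalar toy losses, not the original code): the "dynamic_wce"
# branch of _compute_loss combines the three task losses as
# total = sum_i(exp(-s_i) * loss_i + s_i), which resembles uncertainty-style
# multi-task loss weighting with one scalar s_i per task. In this sketch the scalars
# are wrapped in nn.Parameter so an optimizer could update them; in the class above
# they are created as plain tensors via nn.init.normal_.
def _demo_dynamic_wce(start_loss, end_loss, span_loss):
    start_sig = nn.Parameter(torch.zeros(1))
    end_sig = nn.Parameter(torch.zeros(1))
    span_sig = nn.Parameter(torch.zeros(1))
    total = (torch.exp(-start_sig) * start_loss + start_sig
             + torch.exp(-end_sig) * end_loss + end_sig
             + torch.exp(-span_sig) * span_loss + span_sig)
    return total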