Code example #1
    def __init__(self, config, train_steps=1200000):
        super(BertQueryNER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)

        self.start_outputs = SingleNonLinearClassifier(config.hidden_size, 2,
                                                       config.dropout)
        self.end_outputs = SingleNonLinearClassifier(config.hidden_size, 2,
                                                     config.dropout)

        self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2,
                                                       1, config.dropout)
        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)
        self.train_steps = train_steps
        self.loss_wb = config.weight_start
        self.loss_we = config.weight_end
        self.loss_ws = config.weight_span

        self.device = torch.device("cuda")
        self.loss_type = config.loss_type
        if "dynamic_wce" in self.loss_type:
            start_sig = torch.empty(1)
            end_sig = torch.empty(1)
            span_sig = torch.empty(1)
            # per-loss sigma terms used by the dynamic_wce loss;
            # different initialization scales could be tested here
            self._start_loss_sig = nn.init.normal_(start_sig).to(self.device)
            self._end_loss_sig = nn.init.normal_(end_sig).to(self.device)
            self._span_loss_sig = nn.init.normal_(span_sig).to(self.device)
Code example #2
    def __init__(self, config):
        super(BertMRCNER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)

        self.start_outputs = nn.Linear(config.hidden_size, 2)
        self.end_outputs = nn.Linear(config.hidden_size, 2)

        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)
        self.cluster_layer = config.cluster_layer
Code example #3
File: bert_mrc.py  Project: chanchimin/BERT_MRC
    def __init__(self, config):
        super(BertQueryNER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)

        self.start_outputs = nn.Linear(config.hidden_size, 2)
        self.end_outputs = nn.Linear(config.hidden_size, 2)

        # self.span_embedding = MultiNonLinearClassifier(config.hidden_size*2, 1, config.dropout)
        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)
        self.loss_wb = config.weight_start
        self.loss_we = config.weight_end
        self.loss_ws = config.weight_span
Code example #4
    def __init__(self, config):
        super(BertMRCNER_CLUSTER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)

        self.start_outputs = nn.Linear(config.hidden_size, 2)
        self.end_outputs = nn.Linear(config.hidden_size, 2)

        self.cluster_classify = nn.Linear(config.hidden_size,
                                          config.num_clusters)

        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)

        self.margin = config.margin

        self.gama = config.gama
        self.cluster_layer = config.cluster_layer
        self.pool_mode = config.pool_mode

        self.drop = nn.Dropout(config.dropout_rate)
Code example #5
File: bert_mrc.py  Project: chanchimin/BERT_MRC
class BertQueryNER(nn.Module):
    def __init__(self, config):
        super(BertQueryNER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)

        self.start_outputs = nn.Linear(config.hidden_size, 2)
        self.end_outputs = nn.Linear(config.hidden_size, 2)

        # self.span_embedding = MultiNonLinearClassifier(config.hidden_size*2, 1, config.dropout)
        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)
        self.loss_wb = config.weight_start
        self.loss_we = config.weight_end
        self.loss_ws = config.weight_span

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                start_positions=None,
                end_positions=None,
                span_positions=None):
        """
        Args:
            start_positions: (batch x max_len x 1)
                [[0, 1, 0, 0, 1, 0, 1, 0, 0, ], [0, 1, 0, 0, 1, 0, 1, 0, 0, ]] 
            end_positions: (batch x max_len x 1)
                [[0, 1, 0, 0, 1, 0, 1, 0, 0, ], [0, 1, 0, 0, 1, 0, 1, 0, 0, ]] 
            span_positions: (batch x max_len x max_len) 
                span_positions[k][i][j] is either 0 or 1 and indicates whether the span
                from start position i to end position j of the k-th sentence in the batch is an entity.
        """

        sequence_output, pooled_output, _ = self.bert(
            input_ids,
            token_type_ids,
            attention_mask,
            output_all_encoded_layers=False)

        sequence_heatmap = sequence_output  # batch x seq_len x hidden
        batch_size, seq_len, hid_size = sequence_heatmap.size()

        start_logits = self.start_outputs(
            sequence_heatmap)  # batch x seq_len x 2
        end_logits = self.end_outputs(sequence_heatmap)  # batch x seq_len x 2

        # for every position $i$ in the sequence, concatenate it with every position $j$ to
        # predict whether $i$ and $j$ are the start_pos and end_pos of an entity.

        # start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
        # end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # the shape of start_end_concat[0] is : batch x 1 x seq_len x 2*hidden

        # span_matrix = torch.cat([start_extend, end_extend], 3) # batch x seq_len x seq_len x 2*hidden

        # span_logits = self.span_embedding(span_matrix)  # batch x seq_len x seq_len x 1
        # span_logits = torch.squeeze(span_logits)  # batch x seq_len x seq_len

        if start_positions is not None and end_positions is not None:
            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits.view(-1, 2),
                                  start_positions.view(-1))
            end_loss = loss_fct(end_logits.view(-1, 2), end_positions.view(-1))
            # span_loss_fct = nn.BCEWithLogitsLoss()
            # span_loss = span_loss_fct(span_logits.view(batch_size, -1), span_positions.view(batch_size, -1).float())

            # total_loss = self.loss_wb * start_loss + self.loss_we * end_loss + self.loss_ws * span_loss
            total_loss = self.loss_wb * start_loss + self.loss_we * end_loss
            return total_loss
        else:
            # span_logits = torch.sigmoid(span_logits) # batch x seq_len x seq_len
            start_logits = torch.argmax(start_logits, dim=-1)
            end_logits = torch.argmax(end_logits, dim=-1)

            # return start_logits, end_logits, span_logits
            return start_logits, end_logits
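
The forward pass above returns a scalar loss when gold start/end tag sequences are supplied and argmaxed per-token predictions otherwise. For context, here is a rough usage sketch; it is an assumption rather than code from the repository, and it presupposes that the config object and the BertModel variant expected by __init__ are available:

# Hypothetical usage sketch (assumption, not from the repository).
# config is the same configuration object used in __init__ above.
import torch

model = BertQueryNER(config)
input_ids = torch.randint(0, 100, (2, 16))       # batch of 2, seq_len 16
token_type_ids = torch.zeros_like(input_ids)     # segment ids (all zeros for this sketch)
attention_mask = torch.ones_like(input_ids)

# training: gold 0/1 tag sequences for starts and ends -> scalar loss
start_positions = torch.zeros(2, 16, dtype=torch.long)
end_positions = torch.zeros(2, 16, dtype=torch.long)
loss = model(input_ids, token_type_ids, attention_mask,
             start_positions=start_positions, end_positions=end_positions)

# inference: no gold labels -> per-token start/end predictions
start_pred, end_pred = model(input_ids, token_type_ids, attention_mask)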
Code example #6
class BertMRCNER(nn.Module):
    """
    Desc:
        BERT model for question answering (span_extraction)
        This Module is composed of the BERT model with a linear on top of
        the sequence output that compute start_logits, and end_logits.
    Params:
        config: a BertConfig class instance with the configuration to build a new model.
    Inputs:
        input_ids: torch.LongTensor. of shape [batch_size, sequence_length]
        token_type_ids: an optional torch.LongTensor, [batch_size, sequence_length]
            of the token type [0, 1]. Type 0 corresponds to sentence A, Type 1 corresponds to sentence B.
        attention_mask: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with index select [0, 1]. it is a mask to be used if the input sequence length is smaller
            than the max input sequence length in the current batch.
        start_positions: positions of the first token for the labeled span. torch.LongTensor
            of shape [batch_size, seq_len], if current position is start of entity, the value equals to 1.
            else the value equals to 0.
        end_position: position to the last token for the labeled span.
            torch.LongTensor, [batch_size, seq_len]
    Outputs:
        if "start_positions" and "end_positions" are not None
            output the total_loss which is the sum of the CrossEntropy loss
            for the start and end token positions.
        if "start_positon" or "end_positions" is None
    """
    def __init__(self, config):
        super(BertMRCNER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)

        self.start_outputs = nn.Linear(config.hidden_size, 2)
        self.end_outputs = nn.Linear(config.hidden_size, 2)

        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)
        self.cluster_layer = config.cluster_layer

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                start_positions=None,
                end_positions=None):
        sequence_output, _, _, _ = self.bert(input_ids,
                                             token_type_ids,
                                             attention_mask,
                                             output_all_encoded_layers=False)
        sequence_output = sequence_output.view(-1, self.hidden_size)

        start_logits = self.start_outputs(sequence_output)
        end_logits = self.end_outputs(sequence_output)

        if start_positions is not None and end_positions is not None:
            loss_fct = CrossEntropyLoss()

            start_loss = loss_fct(start_logits.view(-1, 2),
                                  start_positions.view(-1))
            end_loss = loss_fct(end_logits.view(-1, 2), end_positions.view(-1))
            # total_loss = start_loss + end_loss + span_loss
            total_loss = (start_loss + end_loss) / 2
            return total_loss
        else:
            return start_logits, end_logits
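
At inference time (the else branch above) the model returns raw start/end logits, and turning the resulting 0/1 tag sequences into entity spans is left to the caller. A minimal decoding sketch, an assumption rather than code from the repository, using one common heuristic: pair each predicted start with the nearest predicted end at or after it.

# Hypothetical decoding sketch (assumption, not from the repository).
def decode_spans(start_tags, end_tags):
    """start_tags / end_tags: lists of 0/1 per token for one sentence."""
    spans = []
    ends = [j for j, t in enumerate(end_tags) if t == 1]
    for i, t in enumerate(start_tags):
        if t != 1:
            continue
        # nearest predicted end at or after this start
        candidates = [j for j in ends if j >= i]
        if candidates:
            spans.append((i, candidates[0]))
    return spans

print(decode_spans([0, 1, 0, 0, 1, 0], [0, 0, 1, 0, 0, 1]))  # [(1, 2), (4, 5)]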
Code example #7
class BertMRCNER_CLUSTER(nn.Module):
    """
    Desc:
        BERT model for question answering (span_extraction)
        This Module is composed of the BERT model with a linear on top of 
        the sequence output that compute start_logits, and end_logits. 
    Params:
        config: a BertConfig class instance with the configuration to build a new model. 
    Inputs:
        input_ids: torch.LongTensor. of shape [batch_size, sequence_length]
        token_type_ids: an optional torch.LongTensor, [batch_size, sequence_length]
            of the token type [0, 1]. Type 0 corresponds to sentence A, Type 1 corresponds to sentence B. 
        attention_mask: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with index select [0, 1]. it is a mask to be used if the input sequence length is smaller 
            than the max input sequence length in the current batch. 
        start_positions: positions of the first token for the labeled span. torch.LongTensor 
            of shape [batch_size, seq_len], if current position is start of entity, the value equals to 1. 
            else the value equals to 0. 
        end_position: position to the last token for the labeled span. 
            torch.LongTensor, [batch_size, seq_len]
    Outputs:
        if "start_positions" and "end_positions" are not None
            output the total_loss which is the sum of the CrossEntropy loss 
            for the start and end token positions. 
        if "start_positon" or "end_positions" is None 
    """
    def __init__(self, config):
        super(BertMRCNER_CLUSTER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)

        self.start_outputs = nn.Linear(config.hidden_size, 2)
        self.end_outputs = nn.Linear(config.hidden_size, 2)

        self.cluster_classify = nn.Linear(config.hidden_size,
                                          config.num_clusters)

        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)

        self.margin = config.margin

        self.gama = config.gama
        self.cluster_layer = config.cluster_layer
        self.pool_mode = config.pool_mode

        self.drop = nn.Dropout(config.dropout_rate)

    def KLloss(self, probs1, probs2):
        loss = nn.KLDivLoss()
        log_probs1 = F.log_softmax(probs1, 1)
        probs2 = F.softmax(probs2, 1)
        return loss(log_probs1, probs2)

    def get_features(self,
                     input_ids,
                     token_type_ids=None,
                     attention_mask=None,
                     start_positions=None,
                     end_positions=None):
        sequence_output, _, _ = self.bert(input_ids,
                                          token_type_ids,
                                          attention_mask,
                                          output_all_encoded_layers=False)
        sequence_output = sequence_output.view(-1, self.hidden_size)
        start_positions = start_positions.view(-1)
        end_positions = end_positions.view(-1)

        start_pos = np.argwhere(start_positions.cpu().numpy() == 1)
        end_pos = np.argwhere(end_positions.cpu().numpy() == 1)

        start_pos = np.reshape(start_pos, (len(start_pos))).tolist()
        end_pos = np.reshape(end_pos, (len(end_pos))).tolist()
        features = []
        for i, s in enumerate(start_pos):
            if i >= len(end_pos):
                continue
            e = end_pos[i]
            if len(features) == 0:
                features = sequence_output[s:e + 1]
                if self.pool_mode == "sum":
                    features = torch.sum(features, dim=0, keepdim=True)
                elif self.pool_mode == "avg":
                    features = torch.mean(features, dim=0, keepdim=True)
                elif self.pool_mode == "max":
                    features = features.transpose(0, 1).unsqueeze(0)
                    features = F.max_pool1d(
                        input=features,
                        kernel_size=features.size(2)).transpose(1,
                                                                2).squeeze(0)
            else:
                aux = sequence_output[s:e + 1]
                if self.pool_mode == "sum":
                    aux = torch.sum(aux, dim=0, keepdim=True)
                elif self.pool_mode == "avg":
                    aux = torch.mean(aux, dim=0, keepdim=True)
                elif self.pool_mode == "max":
                    aux = aux.transpose(0, 1).unsqueeze(0)
                    aux = F.max_pool1d(input=aux,
                                       kernel_size=aux.size(2)).transpose(
                                           1, 2).squeeze(0)
                features = torch.cat((features, aux), 0)

        #features = self.cluster_outputs(features)

        return features

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                start_positions=None,
                end_positions=None,
                span_positions=None,
                input_truth=None,
                cluster_var=None):
        sequence_output, _, _ = self.bert(input_ids,
                                          token_type_ids,
                                          attention_mask,
                                          output_all_encoded_layers=False)
        #sequence_output = self.dropout(sequence_output.view(-1, self.hidden_size))
        #

        start_logits = self.start_outputs(sequence_output)
        end_logits = self.end_outputs(sequence_output)

        sequence_output = sequence_output.view(-1, self.hidden_size)

        if start_positions is not None and end_positions is not None:
            loss_fct = CrossEntropyLoss()

            start_positions = start_positions.view(-1).long()
            end_positions = end_positions.view(-1).long()

            #ner_loss
            start_loss = loss_fct(start_logits.view(-1, 2), start_positions)
            end_loss = loss_fct(end_logits.view(-1, 2), end_positions)
            #total_loss = start_loss + end_loss + span_loss
            total_loss = (start_loss + end_loss) / 2

            if input_truth is not None:
                #cluster_loss
                loss_fct_cluster = CrossEntropyLoss(cluster_var)
                start_pos = np.argwhere(start_positions.cpu().numpy() == 1)
                end_pos = np.argwhere(end_positions.cpu().numpy() == 1)
                start_pos = np.reshape(start_pos, (len(start_pos))).tolist()
                end_pos = np.reshape(end_pos, (len(end_pos))).tolist()
                features = []
                for i, s in enumerate(start_pos):
                    if i >= len(end_pos):
                        continue
                    e = end_pos[i]
                    if i == 0:
                        features = sequence_output[s:e + 1]
                        if self.pool_mode == "sum":
                            features = torch.sum(features, dim=0, keepdim=True)
                        elif self.pool_mode == "avg":
                            features = torch.mean(features,
                                                  dim=0,
                                                  keepdim=True)
                        elif self.pool_mode == "max":
                            features = features.transpose(0, 1).unsqueeze(0)
                            features = F.max_pool1d(
                                input=features,
                                kernel_size=features.size(2)).transpose(
                                    1, 2).squeeze(0)
                    else:

                        aux = sequence_output[s:e + 1]
                        if self.pool_mode == "sum":
                            aux = torch.sum(aux, dim=0, keepdim=True)
                        elif self.pool_mode == "avg":
                            aux = torch.mean(aux, dim=0, keepdim=True)
                        elif self.pool_mode == "max":
                            aux = aux.transpose(0, 1).unsqueeze(0)
                            aux = F.max_pool1d(
                                input=aux, kernel_size=aux.size(2)).transpose(
                                    1, 2).squeeze(0)
                        features = torch.cat((features, aux), 0)

                if len(features) == 0:
                    return total_loss
                features = self.drop(features)
                prob = self.cluster_classify(features)
                CEloss1 = loss_fct_cluster(prob, input_truth[:len(prob)])
                #CEloss2=loss_fct(prob_C, input_truth[:len(prob_C)])
                #KL=self.KLloss(prob, prob_C)
                #cluster_loss=CEloss1+CEloss2+KL

                #cluster_loss = loss_fct_cluster(cluster, input_truth[:len(cluster)])
                #print("total_loss:  ",total_loss)
                #print("cluster_loss:    ", cluster_loss)
                return total_loss + self.gama * CEloss1
            else:
                return total_loss
        else:

            span_logits = torch.ones(start_logits.size(0),
                                     start_logits.size(1),
                                     start_logits.size(1)).cuda()
            return start_logits, end_logits, span_logits
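
The repeated pooling branches in get_features and forward reduce the hidden vectors of one entity span (span_len x hidden) to a single feature vector that is fed to cluster_classify. A small self-contained sketch of the three pool_mode options on a dummy tensor (the hidden size of 4 is an arbitrary example), showing that the max_pool1d construction is equivalent to a plain max over the token dimension:

# Standalone sketch of the pooling modes (assumption: dummy sizes, not repository code).
import torch
import torch.nn.functional as F

span = torch.randn(3, 4)  # 3 tokens in the span, hidden size 4

summed = torch.sum(span, dim=0, keepdim=True)   # pool_mode == "sum"
avged = torch.mean(span, dim=0, keepdim=True)   # pool_mode == "avg"

# pool_mode == "max": reshape to (1, hidden, span_len) and max-pool over the span axis
pooled = F.max_pool1d(span.transpose(0, 1).unsqueeze(0),
                      kernel_size=span.size(0)).transpose(1, 2).squeeze(0)
assert torch.allclose(pooled, span.max(dim=0, keepdim=True).values)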
Code example #8
class BertQueryNER(nn.Module):
    def __init__(self, config, train_steps=1200000):
        super(BertQueryNER, self).__init__()
        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)

        self.start_outputs = SingleNonLinearClassifier(config.hidden_size, 2,
                                                       config.dropout)
        self.end_outputs = SingleNonLinearClassifier(config.hidden_size, 2,
                                                     config.dropout)

        self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2,
                                                       1, config.dropout)
        self.hidden_size = config.hidden_size
        self.bert = self.bert.from_pretrained(config.bert_model)
        self.train_steps = train_steps
        self.loss_wb = config.weight_start
        self.loss_we = config.weight_end
        self.loss_ws = config.weight_span

        self.device = torch.device("cuda")
        self.loss_type = config.loss_type
        if "dynamic_wce" in self.loss_type:
            start_sig = torch.empty(1)
            end_sig = torch.empty(1)
            span_sig = torch.empty(1)
            # per-loss sigma terms used by the dynamic_wce loss;
            # different initialization scales could be tested here
            self._start_loss_sig = nn.init.normal_(start_sig).to(self.device)
            self._end_loss_sig = nn.init.normal_(end_sig).to(self.device)
            self._span_loss_sig = nn.init.normal_(span_sig).to(self.device)

    def update_loss_ratio(self,
                          current_train_step=None,
                          decay_step=5000,
                          lower_bound_weight=0.6,
                          upper_bound_weight=1.5,
                          decay_base=3.0,
                          increase_base=1.5):
        if current_train_step is None:
            return
        if current_train_step > decay_step:
            loss_wb = self.loss_wb * (decay_base**
                                      -(current_train_step / self.train_steps))
            loss_we = self.loss_we * (decay_base**
                                      -(current_train_step / self.train_steps))
            self.loss_wb = loss_wb if loss_wb > lower_bound_weight else lower_bound_weight
            self.loss_we = loss_we if loss_we > lower_bound_weight else lower_bound_weight

            loss_ws = self.loss_ws * (increase_base**(current_train_step /
                                                      self.train_steps))
            self.loss_ws = loss_ws if loss_ws <= upper_bound_weight else upper_bound_weight
            if current_train_step % 1000 == 0:
                print(
                    f"*** *** *** >>> update loss weight: {self.loss_wb}, {self.loss_we}, {self.loss_ws}"
                )

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                start_positions=None,
                end_positions=None,
                span_positions=None,
                span_label_mask=None,
                current_step=None):
        """
        Args:
            start_positions: (batch x max_len x 1)
                [[0, 1, 0, 0, 1, 0, 1, 0, 0, ], [0, 1, 0, 0, 1, 0, 1, 0, 0, ]]
            end_positions: (batch x max_len x 1)
                [[0, 1, 0, 0, 1, 0, 1, 0, 0, ], [0, 1, 0, 0, 1, 0, 1, 0, 0, ]]
            span_positions: (batch x max_len x max_len)
                span_positions[k][i][j] is either 0 or 1 and indicates whether the span
                from start position i to end position j of the k-th sentence in the batch is an entity.
        """

        sequence_output, pooled_output, _ = self.bert(
            input_ids,
            token_type_ids,
            attention_mask,
            output_all_encoded_layers=False)

        sequence_heatmap = sequence_output  # batch x seq_len x hidden
        batch_size, seq_len, hid_size = sequence_heatmap.size()

        start_logits = self.start_outputs(
            sequence_heatmap)  # batch x seq_len x 2
        end_logits = self.end_outputs(sequence_heatmap)  # batch x seq_len x 2

        # for every position $i$ in the sequence, concatenate it with every position $j$ to
        # predict whether $i$ and $j$ are the start_pos and end_pos of an entity.
        start_extend = sequence_heatmap.unsqueeze(2).expand(
            -1, -1, seq_len, -1)
        end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # start_extend / end_extend: batch x seq_len x seq_len x hidden

        span_matrix = torch.cat([start_extend, end_extend],
                                3)  # batch x seq_len x seq_len x 2*hidden

        span_logits = self.span_embedding(
            span_matrix)  # batch x seq_len x seq_len x 1
        span_logits = torch.squeeze(span_logits)  # batch x seq_len x seq_len

        if start_positions is not None and end_positions is not None:
            # self.update_loss_ratio(current_train_step=current_step)
            valid_num = torch.sum(token_type_ids)
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            start_loss = loss_fct(start_logits.view(-1, 2),
                                  start_positions.view(-1))
            start_loss = torch.sum(start_loss * token_type_ids.view(-1))
            start_loss = start_loss / valid_num.float()
            end_loss = loss_fct(end_logits.view(-1, 2), end_positions.view(-1))
            end_loss = torch.sum(end_loss * token_type_ids.view(-1))
            end_loss = end_loss / valid_num.float()
            span_loss_fct = nn.BCEWithLogitsLoss(reduction="none")
            span_loss = span_loss_fct(
                span_logits.view(batch_size, -1),
                span_positions.view(batch_size, -1).float())
            valid_span_num = torch.sum(span_label_mask)
            span_loss = torch.sum(
                span_loss.view(-1) * span_label_mask.view(-1))
            span_loss = span_loss / valid_span_num.float()
            total_loss = self._compute_loss(start_loss,
                                            end_loss,
                                            span_loss,
                                            loss_type=self.loss_type)
            # total_loss = self.loss_wb * start_loss + self.loss_we * end_loss + self.loss_ws * span_loss
            return total_loss
        else:
            span_scores = torch.sigmoid(
                span_logits)  # batch x seq_len x seq_len
            start_labels = torch.argmax(start_logits, dim=-1)
            end_labels = torch.argmax(end_logits, dim=-1)
            return start_labels, end_labels, span_scores

    def _compute_loss(self, start_loss, end_loss, span_loss, loss_type="ce"):
        if loss_type == "ce":
            total_loss = self.loss_wb * start_loss + self.loss_we * end_loss + self.loss_ws * span_loss
            return total_loss
        elif loss_type == "dynamic_wce":
            b_factor = torch.exp(-self._start_loss_sig)
            b_loss = b_factor * start_loss + self._start_loss_sig

            e_factor = torch.exp(-self._end_loss_sig)
            e_loss = e_factor * end_loss + self._end_loss_sig

            s_factor = torch.exp(-self._span_loss_sig)
            s_loss = s_factor * span_loss + self._span_loss_sig
            total_loss = b_loss + e_loss + s_loss
            return total_loss
        else:
            raise ValueError("Loss Type doesnot exists. ")