Beispiel #1
0
class BertForABSA(BertPreTrainedModel):
    """BERT encoder followed by an ``HSUM`` head for aspect-based sentiment analysis.

    ``HSUM`` receives the hidden states of every transformer layer (the
    embedding output at index 0 is dropped) together with the extended
    attention mask.
    """

    def __init__(self, config, num_labels=3):
        """
        Args:
            config: BERT configuration used to build the encoder and the head.
            num_labels: number of output classes, forwarded to ``HSUM``.
        """
        super().__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        # The leading ``4`` is forwarded verbatim to HSUM; its meaning is
        # defined by HSUM (presumably the number of layers it combines — confirm).
        self.hsum = HSUM(4, config, num_labels)
        self.init_weights()

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None):
        """Run BERT and the HSUM head.

        Args:
            input_ids: token ids.
            token_type_ids: optional segment ids.
            attention_mask: optional padding mask; when omitted every token
                is attended to.
            labels: optional targets; when given the loss is returned.

        Returns:
            The loss produced by ``HSUM`` when ``labels`` is given, otherwise
            the logits.
        """
        if attention_mask is None:
            # get_extended_attention_mask cannot take None. Attending to every
            # token matches what BertModel itself does for a missing mask.
            attention_mask = input_ids.new_ones(input_ids.shape)

        layers = self.bert(input_ids,
                           token_type_ids=token_type_ids,
                           attention_mask=attention_mask,
                           output_hidden_states=True)['hidden_states']
        mask = self.bert.get_extended_attention_mask(attention_mask,
                                                     input_ids.shape,
                                                     device=layers[0].device)
        # layers[0] is the embedding output; HSUM only sees transformer layers.
        loss, logits = self.hsum(layers[1:], mask, labels)
        return loss if labels is not None else logits
Beispiel #2
0
class BertFold(nn.Module):
    """ProtBert encoder with an evolutionary-feature input and a pairwise
    distance decoder.

    Token embeddings from ``Rostlab/prot_bert_bfd`` are summed with a linear
    projection of the 21-dim ``evo`` features, run through the BERT encoder
    (the pooler is removed), and decoded by ``PairwiseDistanceDecoder``.
    """

    def __init__(
        self,
        pretrained: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            pretrained: load pretrained ProtBert weights; otherwise build a
                randomly initialised model from the same config.
            gradient_checkpointing: forwarded to ``from_pretrained`` (only
                used when ``pretrained`` is True).
        """
        super().__init__()
        if pretrained:
            self.bert = BertModel.from_pretrained(
                'Rostlab/prot_bert_bfd',
                gradient_checkpointing=gradient_checkpointing,
            )
        else:
            conf = BertConfig.from_pretrained('Rostlab/prot_bert_bfd')
            self.bert = BertModel(conf)

        # noinspection PyUnresolvedReferences
        dim = self.bert.config.hidden_size

        # Projects the 21-dim per-residue evolutionary features to hidden size.
        self.evo_linear = nn.Linear(21, dim)

        self.decoder_dist = PairwiseDistanceDecoder(dim)
        # self.decoder_phi = ElementwiseAngleDecoder(dim, 2)
        # self.decoder_psi = ElementwiseAngleDecoder(dim, 2)

        self.evo_linear.apply(init_weights)
        self.decoder_dist.apply(init_weights)
        # self.decoder_phi.apply(init_weights)
        # self.decoder_psi.apply(init_weights)

        # forward() never calls BertModel.forward, so the pooler is unused.
        del self.bert.pooler

    def forward(
        self,
        inputs: ProteinNetBatch,
        targets: Optional[BertFoldTargets] = None,
    ) -> BertFoldOutput:
        """Encode a batch and decode pairwise distances.

        Args:
            inputs: batch providing ``input_ids``, ``evo`` and ``attention_mask``.
            targets: optional targets; when given, the loss and MAE metric are
                computed as well.

        Returns:
            ``BertFoldOutput`` with ``y_hat`` only when ``targets`` is None;
            otherwise also ``loss``, ``loss_dist`` and ``mae_l_8``.
        """
        x_emb = self.bert.embeddings(inputs['input_ids'])
        x_evo = self.evo_linear(inputs['evo'].type_as(x_emb))
        x = x_emb + x_evo
        extended_attention_mask = self.bert.get_extended_attention_mask(
            inputs['attention_mask'],
            inputs['input_ids'].shape,
            inputs['input_ids'].device,
        )
        # Call the module rather than ``.forward`` so registered hooks run.
        x = self.bert.encoder(x, attention_mask=extended_attention_mask)[0]

        targets_dist = None if targets is None else targets.dist
        # targets_phi = None if targets is None else targets.phi
        # targets_psi = None if targets is None else targets.psi

        outs = [
            self.decoder_dist(x, targets_dist),
            # self.decoder_phi(x, targets_phi),
            # self.decoder_psi(x, targets_psi),
        ]

        y_hat = tuple(out.y_hat for out in outs)

        if targets is None:
            return BertFoldOutput(y_hat=y_hat)

        loss = torch.stack([out.loss for out in outs]).sum()

        # Collect metrics (no gradients needed).
        with torch.no_grad():
            # Long-range MAE restricted by an 8.0 contact threshold.
            mae_l8_fn = MAEForSeq(contact_thre=8.)
            results = mae_l8_fn(
                inputs=y_hat[0][targets.dist.indices],
                targets=targets.dist.values,
                indices=targets.dist.indices,
            )
            if len(results) > 0:
                mae_l_8 = (results.mean().detach().item(), len(results))
            else:
                mae_l_8 = (0, 0)

        return BertFoldOutput(
            y_hat=y_hat,
            loss=loss,
            loss_dist=outs[0].loss_and_cnt,
            # loss_phi=outs[1].loss_and_cnt,
            # loss_psi=outs[2].loss_and_cnt,
            mae_l_8=mae_l_8,
        )
class BertForMWA(BertPreTrainedModel):
    """BERT token classifier with multi-grained word-aligned attention (MWA).

    Three word-segmentation views of the input (``word_slice_*`` /
    ``word_length_*``) each drive one segment-aligned attention pass over the
    BERT output. The three attention outputs and the raw BERT output are each
    projected by the shared ``ensemble_linear`` + tanh, summed, and classified
    per token by ``qa_outputs``.
    """

    def __init__(self, config, label2ids, device):
        """
        Args:
            config: BERT configuration.
            label2ids: tag-to-id mapping; its length sets the number of output
                classes and it is handed to the CRF layer.
            device: device forwarded to the CRF layer.
        """
        super(BertForMWA, self).__init__(config)
        # config.output_attentions = True
        self.bert = BertModel(config)  # BERT encoder
        # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.activation = nn.ReLU()
        self.label_num = len(label2ids)
        # Use a single attention head. Author's note (translated), from
        # https://mp.weixin.qq.com/s/FBHjAGLQboufB2pTLsPp4A :
        #   "My first implementation did use multi-head attention, but the
        #   results were poor. Plain BERT+softmax reached about 0.943 F1 on
        #   the test set (span-level F1 over entity text + entity type). With
        #   this method and multi-head attention, several runs only reached
        #   about 0.940, i.e. essentially no gain. I then gradually reduced
        #   the number of heads from 12 to 6 and from 6 to 1. Surprisingly the
        #   score kept improving: with one head it reached about 0.951, a gain
        #   of roughly 0.8-1 point. This at least shows that multi-head
        #   attention does not suit every dataset and needs more testing."
        self.head_num = 1

        # Learnable weight mixing the max and the mean of the attention inside
        # a multi-character word span (see calculate_scale).
        self.mix_lambda = nn.Parameter(torch.tensor(0.5))

        hidden = config.hidden_size

        def _dense():
            return nn.Linear(hidden, hidden, bias=True)

        # One Q/K/V/O projection set per segmentation view. Creation order is
        # kept unchanged so parameter names and initialisation order match.
        self.linear_q = _dense()
        self.linear_k = _dense()
        self.linear_v = _dense()
        self.linear_o = _dense()

        # NOTE(review): linear_o2 / linear_o3 are never used. MultiHeadSegATT
        # always applies self.linear_o. They are kept so that existing
        # checkpoints still load; confirm whether each view should use its own
        # output projection.
        self.linear_q2 = _dense()
        self.linear_k2 = _dense()
        self.linear_v2 = _dense()
        self.linear_o2 = _dense()

        self.linear_q3 = _dense()
        self.linear_k3 = _dense()
        self.linear_v3 = _dense()
        self.linear_o3 = _dense()

        self.ensemble_linear = _dense()
        self.ensemble_activation = nn.Tanh()
        self.dropout_1 = nn.Dropout(config.attention_probs_dropout_prob)
        self.dropout_2 = nn.Dropout(config.attention_probs_dropout_prob)
        self.dropout_3 = nn.Dropout(config.attention_probs_dropout_prob)
        self.qa_outputs = nn.Linear(config.hidden_size, len(label2ids))
        # The blog author reports that CRF brings no improvement; the layer is
        # still built here but forward() does not use it (see the commented
        # CRF loss below).
        self.crf = CRF(tagset_size=len(label2ids),
                       tag_dictionary=label2ids,
                       device=device,
                       is_bert=True)
        self.init_weights()  # initialise weights

    def forward(self,
                input_ids,
                word_length_1,
                word_length_2,
                word_length_3,
                word_slice_1,
                word_slice_2,
                word_slice_3,
                token_type_ids=None,
                input_lens=None,
                attention_mask=None,
                labels=None):
        """Tag every token.

        Args:
            input_ids: token ids.
            word_length_{1,2,3}: per segmentation view, the length of each word.
            word_slice_{1,2,3}: per segmentation view, the start position of
                each word.
            token_type_ids: optional segment ids.
            input_lens: unused (only referenced by the disabled CRF loss).
            attention_mask: optional padding mask; defaults to all ones.
            labels: optional gold tag ids.

        Returns:
            ``(pred_ids, loss)`` when ``labels`` is given, otherwise
            ``pred_ids`` (argmax over the tag logits).
        """
        if attention_mask is None:
            # get_extended_attention_mask and the loss below both need a mask.
            attention_mask = input_ids.new_ones(input_ids.shape)

        # Pass by keyword. BertModel.forward takes (input_ids, attention_mask,
        # token_type_ids, ...) positionally, so a positional call swaps the
        # mask and segment ids. Indexing [0] works for tuple and ModelOutput.
        encoded_layers = self.bert(input_ids,
                                   token_type_ids=token_type_ids,
                                   attention_mask=attention_mask)[0]
        extend_mask = self.bert.get_extended_attention_mask(
            attention_mask, input_ids.size(), input_ids.device)

        seg_attention_out = self.MultiHeadSegATT(
            encoded_layers, encoded_layers, encoded_layers, self.linear_q,
            self.linear_k, self.linear_v, word_length_1, word_slice_1,
            extend_mask, self.dropout_1)

        seg_attention_out2 = self.MultiHeadSegATT(
            encoded_layers, encoded_layers, encoded_layers, self.linear_q2,
            self.linear_k2, self.linear_v2, word_length_2, word_slice_2,
            extend_mask, self.dropout_2)

        seg_attention_out3 = self.MultiHeadSegATT(
            encoded_layers, encoded_layers, encoded_layers, self.linear_q3,
            self.linear_k3, self.linear_v3, word_length_3, word_slice_3,
            extend_mask, self.dropout_3)

        def _project(hidden_states):
            return self.ensemble_activation(self.ensemble_linear(hidden_states))

        # Ensemble by character position. The shared linear + tanh acts
        # position-wise, so it is applied to whole sequences at once instead
        # of looping over positions. The summation order of the original
        # loop is preserved.
        sequence_output = (_project(seg_attention_out) +
                           _project(seg_attention_out3) +
                           _project(seg_attention_out2) +
                           _project(encoded_layers))
        # sequence_output = seg_attention_out2 # For ablation experiment

        logits = self.qa_outputs(sequence_output)
        pred_ids = torch.argmax(logits, dim=-1)

        if labels is None:
            return pred_ids

        active_loss = attention_mask.view(-1) == 1
        active_logits = logits.view(-1, self.label_num)[active_loss]
        active_labels = labels.view(-1)[active_loss]
        loss_fct = CrossEntropyLoss(ignore_index=0, reduction="sum")
        loss = loss_fct(active_logits, active_labels)
        # return logits, self.crf.calculate_loss(logits, tag_list=labels, lengths=input_lens)
        return pred_ids, loss

    def MultiHeadSegATT(self, q, k, v, Q, K, V, word_lengths,
                        word_slice_indexs, attention_mask, dropout_obj):
        """Segment-aligned self-attention for one segmentation view.

        ``q``/``k``/``v`` are projected by ``Q``/``K``/``V`` followed by ReLU.
        Scaled dot-product attention is computed with the additive
        ``attention_mask``. The weights are then rescaled per word span by
        ``calculate_scale``. Finally the result goes through ``self.linear_o``,
        ``dropout_obj`` and ReLU.
        """
        q, k, v = Q(q), K(k), V(v)
        q = self.activation(q)
        k = self.activation(k)
        v = self.activation(v)

        # -> batch_size, head_num, seq_len, sub_dim
        q = self._reshape_to_batches(q)
        k = self._reshape_to_batches(k)
        v = self._reshape_to_batches(v)

        dk = q.size()[-1]
        # batch_size, head_num, seq_len_row, seq_len_col
        scores = q @ k.transpose(-2, -1) / math.sqrt(dk)
        scores = scores + attention_mask
        attention = F.softmax(scores, dim=-1)

        # The scale is computed from detached weights, so no gradient flows
        # through the scale computation itself.
        scalar = self.calculate_scale(attention.detach(), word_slice_indexs,
                                      word_lengths)

        y = (attention * scalar) @ v  # apply the word-aligned attention

        y = self._reshape_from_batches(y)
        y = self.linear_o(y)
        y = dropout_obj(y)
        y = self.activation(y)
        return y

    def _reshape_to_batches(self, x):
        """(batch, seq, feat) -> (batch, head_num, seq, feat // head_num)."""
        batch_size, seq_len, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, seq_len, self.head_num, sub_dim) \
            .permute(0, 2, 1, 3)

    def _reshape_from_batches(self, x):
        """(batch, head_num, seq, sub_dim) -> (batch, seq, head_num * sub_dim)."""
        batch_size, head_num, seq_len, sub_dim = x.size()
        out_dim = head_num * sub_dim
        return x.permute(0, 2, 1, 3).reshape(batch_size, seq_len, out_dim)

    def extra_repr(self):
        # The previous version read self.in_features / self.bias, which this
        # module never defines, so repr(model) raised AttributeError. All
        # attention projections use bias=True.
        return 'in_features={}, head_num={}, bias={}, activation={}'.format(
            self.config.hidden_size,
            self.head_num,
            True,
            self.activation,
        )

    def calculate_scale(self, att_weights, seg_slice, seg_length):
        """Build a multiplicative mask that aligns attention to word spans.

        For each word span longer than one character, the span's attention
        columns are rescaled so every position gets
        ``mix_lambda * max + (1 - mix_lambda) * mean`` of the span. Spans of
        length <= 1 get scale 1. Columns not covered by any visited span stay 0.

        Args:
            att_weights: (batch, head, seq_len_row, seq_len_col) attention.
            seg_slice: per-sample word start positions.
            seg_length: per-sample word lengths; 0 marks padding.

        Returns:
            A float tensor shaped like ``att_weights``.
        """
        batch_size, head_num, seq_len_row, seq_len_col = att_weights.size()
        batch_size = int(batch_size)
        mask = torch.zeros(att_weights.size(), device=seg_length.device)
        # Number of real (non-padding) words per sample; iteration stops there.
        stop_condition = (seg_length != 0).sum(dim=1)
        # Pooling runs over character spans of different lengths depending on
        # the segmentation, which is hard to parallelise efficiently. The
        # sequence is therefore processed with Python loops, so a larger
        # batch size does not improve GPU utilisation.
        one_minus_lambda = 1 - self.mix_lambda

        for batch_idx in range(batch_size):
            if att_weights[batch_idx].nelement() == 0:
                continue
            for s in range(int(stop_condition[batch_idx])):
                token_pos = seg_slice[batch_idx][s]
                token_length = seg_length[batch_idx][s]
                if token_pos > stop_condition[batch_idx]:
                    break
                span = slice(token_pos, token_pos + token_length)
                if bool(token_length > 1):
                    att = att_weights[batch_idx, :, :, span]
                    if att.nelement() == 0:
                        continue
                    span_mean = att.mean(-1, keepdim=True)
                    span_max = att.max(-1, keepdim=True)[0]
                    # try to make attention more balanced
                    # mean = mean * (att <= mean).float() + att * (att > mean).float()
                    mix = span_max * self.mix_lambda + span_mean * one_minus_lambda
                    mask[batch_idx, :, :, span] = mix / att
                else:
                    mask[batch_idx, :, :, span] = 1.0
        return mask
Beispiel #4
0
class BiaffineDependencyS2TQeuryParser(BertPreTrainedModel):
    """
    Dependency parser following the setup of
    [Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)]
    (https://arxiv.org/abs/1611.01734),
    but using token-to-token MRC to extract each query's parent and its label.
    Each sample scores every word as the parent of the queried word, restricted
    to the words allowed by ``mrc_mask``, and predicts a dependency tag per word.
    """
    def __init__(self, config: Union[BertMrcS2TQueryDependencyConfig,
                                     RobertaMrcS2TQueryDependencyConfig]):
        """
        Args:
            config: BERT- or RoBERTa-flavoured MRC dependency config. It
                supplies ``dep_tags``, ``pos_tags``, ``pos_dim``,
                ``additional_layer``, ``additional_layer_type``,
                ``additional_layer_dim`` and ``mrc_dropout``.
        """
        super().__init__(config)

        self.config = config

        num_dep_labels = len(config.dep_tags)
        num_pos_labels = len(config.pos_tags)
        hidden_size = config.additional_layer_dim

        if config.pos_dim > 0:
            self.pos_embedding = nn.Embedding(num_pos_labels, config.pos_dim)
            nn.init.xavier_uniform_(self.pos_embedding.weight)
            # The LSTM encoder takes the concatenated width directly, so a fuse
            # projection is only needed for other encoders when widths differ.
            if config.additional_layer_type != "lstm" and config.pos_dim + config.hidden_size != hidden_size:
                self.fuse_layer = nn.Linear(
                    config.pos_dim + config.hidden_size, hidden_size)
                nn.init.xavier_uniform_(self.fuse_layer.weight)
                self.fuse_layer.bias.data.zero_()
            else:
                self.fuse_layer = None
        else:
            self.pos_embedding = None
            # Defined in every configuration so the attribute always exists.
            self.fuse_layer = None

        if isinstance(config, BertMrcS2TQueryDependencyConfig):
            self.bert = BertModel(config)
            self.arch = "bert"
        else:
            self.roberta = RobertaModel(config)
            self.arch = "roberta"
        # self.is_subtree_feedforward = nn.Sequential(
        #     nn.Linear(config.hidden_size, config.hidden_size),
        #     nn.GELU(),
        #     nn.Dropout(config.mrc_dropout),
        #     nn.Linear(config.hidden_size, 1),
        # )

        if config.additional_layer > 0:
            if config.additional_layer_type == "transformer":
                new_config = deepcopy(config)
                new_config.hidden_size = hidden_size
                new_config.num_hidden_layers = config.additional_layer
                new_config.hidden_dropout_prob = new_config.attention_probs_dropout_prob = config.mrc_dropout
                # new_config.attention_probs_dropout_prob = config.biaf_dropout  # todo add to hparams and tune
                self.additional_encoder = BertEncoder(new_config)
                self.additional_encoder.apply(self._init_bert_weights)
            else:
                assert hidden_size % 2 == 0, "Bi-LSTM need an even hidden_size"
                self.additional_encoder = StackedBidirectionalLstmSeq2SeqEncoder(
                    input_size=config.pos_dim + config.hidden_size,
                    hidden_size=hidden_size // 2,
                    num_layers=config.additional_layer,
                    recurrent_dropout_probability=config.mrc_dropout,
                    use_highway=True)

        else:
            self.additional_encoder = None

        self.parent_feedforward = nn.Linear(hidden_size, 1)
        self.parent_tag_feedforward = nn.Linear(hidden_size, num_dep_labels)

        # self.child_feedforward = nn.Linear(hidden_size, 1)
        # self.child_tag_feedforward = nn.Linear(hidden_size, num_dep_labels)

        self._dropout = nn.Dropout(config.mrc_dropout)
        # self._dropout = InputVariationalDropout(config.mrc_dropout)

        # Xavier-init every direct nn.Linear child with a zero bias.
        for layer in self.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    def _init_bert_weights(self, module):
        """ Initialize the weights. copy from transformers.BertPreTrainedModel"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @overrides
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
        mrc_mask: torch.BoolTensor,
        parent_idxs: torch.LongTensor = None,
        parent_tags: torch.LongTensor = None,
        # is_subtree: torch.BoolTensor = None
    ):
        """Score the parent and the dependency tag of the queried word.

        Args:
            token_ids: [batch_size, num_word_pieces]
            type_ids: [batch_size, num_word_pieces]
            offsets: [batch_size, num_words, 2] word-piece span of each word
            wordpiece_mask: [batch_size, num_word_pieces]
            pos_tags: [batch_size, num_words]
            word_mask: [batch_size, num_words]
            mrc_mask: [batch_size, num_words] words allowed as parent
            parent_idxs: [batch_size] gold parent index (optional)
            parent_tags: [batch_size] gold parent tag (optional)
        Returns:
            A tuple ``(parent_probs, parent_tag_probs[, arc_loss[, tag_loss]])``:
            parent_probs: [batch_size, num_words]
            parent_tag_probs: [batch_size, num_words, num_tags]
            arc_loss: appended when ``parent_idxs`` is given
            tag_loss: appended when both ``parent_idxs`` and ``parent_tags``
                are given
        """

        cls_embedding, embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._dropout(embedded_text_input)
        cls_embedding = self._dropout(cls_embedding)

        # [bsz]
        # subtree_scores = self.is_subtree_feedforward(cls_embedding).squeeze(-1)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                # self.bert does not exist for the RoBERTa variant; use
                # whichever backbone was built.
                embed_model = self.bert if self.arch == "bert" else self.roberta
                extended_attention_mask = embed_model.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size = encoded_text.size(0)

        # shape (batch_size, sequence_length, tag_classes)
        parent_tag_scores = self.parent_tag_feedforward(encoded_text)
        # shape (batch_size, sequence_length)
        parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)

        # Mask out impossible positions (padding, or words outside mrc_mask).
        minus_inf = -1e8
        mrc_mask = torch.logical_and(mrc_mask, word_mask)
        parent_scores = parent_scores + (~mrc_mask).float() * minus_inf

        parent_probs = F.softmax(parent_scores, dim=-1)
        parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)

        # output = (torch.sigmoid(subtree_scores), parent_probs, parent_tag_probs)  # todo check if log in dp evaluation
        output = (parent_probs, parent_tag_probs
                  )  # todo check if log in dp evaluation

        # add losses
        # if is_subtree is not None:
        #     subtree_loss = F.binary_cross_entropy_with_logits(subtree_scores, is_subtree.float())
        #     output = output + (subtree_loss, )

        if parent_idxs is not None:
            # The subtree head is disabled, so every sample gets weight 1.
            # This is built from batch_size because parent_tags may be None;
            # torch.ones_like(None) crashed at inference time.
            sample_mask = torch.ones(batch_size, device=encoded_text.device)
            # [bsz]
            batch_range_vector = get_range_vector(batch_size,
                                                  get_device_of(encoded_text))
            # [bsz, seq_len]
            parent_logits = F.log_softmax(parent_scores, dim=-1)
            parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs]
            parent_arc_nll = (parent_arc_nll *
                              sample_mask).sum() / (sample_mask.sum() + 1e-8)
            output = output + (parent_arc_nll, )

            if parent_tags is not None:
                parent_tag_nll = F.cross_entropy(
                    parent_tag_scores[batch_range_vector, parent_idxs],
                    parent_tags,
                    reduction="none")
                parent_tag_nll = (parent_tag_nll * sample_mask).sum() / (
                    sample_mask.sum() + 1e-8)
                output = output + (parent_tag_nll, )

        return output

    def get_word_embedding(
        self,
        token_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # type: ignore
        """Return the first-token ([CLS]) embedding and word-level embeddings.

        Each word embedding is the mean of its word-piece embeddings selected
        by ``offsets``. Words with an empty span get zeros.
        """
        # Shape: [batch_size, num_wordpieces, embedding_size].
        embed_model = self.bert if self.arch == "bert" else self.roberta
        embeddings = embed_model(token_ids,
                                 token_type_ids=type_ids,
                                 attention_mask=wordpiece_mask)[0]

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = allennlp_util.batched_span_select(
            embeddings, offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(
            span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0

        return embeddings[:, 0, :], orig_embeddings