import torch
import torch.nn as nn
from transformers import BertModel

# PositionalEncoding, GlobalAttention, BahdanauAttention, RNNEncoder, and
# get_mask_from_sequence_lengths are project-local helpers assumed importable
# here (the last appears to match the AllenNLP utility of the same name).


class BertSimpleClassifier(nn.Module):
    def __init__(self, bert_pretrained_weights, num_class):

        super().__init__()
        self.bert = BertModel.from_pretrained(bert_pretrained_weights)
        self.positional_encoding = PositionalEncoding(input_dim=768)

        # self.linear_doc = nn.Linear(768, 768)
        # self.linear_prompt = nn.Linear(768, 768)

        self.linear_layer = nn.Linear(768 * 2, num_class)
        self.dropout_layer = nn.Dropout(0.5)
        self.criterion = nn.NLLLoss(reduction='sum')

        nn.init.uniform_(self.linear_layer.weight.data, -0.1, 0.1)
        nn.init.zeros_(self.linear_layer.bias.data)

    def forward(self,
                inputs,
                mask,
                sent_counts,
                sent_lens,
                prompt_inputs,
                prompt_mask,
                prompt_sent_counts,
                prompt_sent_lens,
                label=None):
        """

        :param prompt_sent_lens:
        :param prompt_sent_counts:
        :param prompt_inputs:
        :param prompt_mask:
        :param inputs:  [batch size, max sent count, max sent len]
        :param mask:    [batch size, max sent count, max sent len]
        :param sent_counts: [batch size]
        :param sent_lens: [batch size, max sent count]
        :param label: [batch size]
        :return:
        """
        batch_size = inputs.shape[0]
        max_sent_count = inputs.shape[1]
        max_sent_length = inputs.shape[2]

        inputs = inputs.view(-1, inputs.shape[-1])
        mask = mask.view(-1, mask.shape[-1])

        # [batch size * max sent count, max sent len, hid size]
        last_hidden_states = self.bert(input_ids=inputs,
                                       attention_mask=mask)[0]
        last_hidden_states = last_hidden_states.view(batch_size,
                                                     max_sent_count,
                                                     max_sent_length, -1)
        last_hidden_states = self.dropout_layer(last_hidden_states)

        prompt_inputs = prompt_inputs.view(-1, prompt_inputs.shape[-1])
        prompt_mask = prompt_mask.view(-1, prompt_mask.shape[-1])
        prompt_hidden_states = self.bert(input_ids=prompt_inputs,
                                         attention_mask=prompt_mask)[0]
        prompt_hidden_states = self.dropout_layer(prompt_hidden_states)

        docs = []
        lens = []
        for i in range(0, batch_size):
            doc = []
            sent_count = sent_counts[i]
            sent_len = sent_lens[i]

            for j in range(sent_count):
                length = sent_len[j]
                cur_sent = last_hidden_states[i, j, :length, :]
                # print('cur sent shape', cur_sent.shape)
                doc.append(cur_sent)

            # concatenate the doc's sentences, add positional encoding, then
            # mean-pool over tokens -> [1, hid size]
            doc_vec = torch.cat(doc, dim=0).unsqueeze(0)
            doc_vec = self.positional_encoding.forward(doc_vec)
            doc_vec = torch.mean(doc_vec, dim=1)

            lens.append(doc_vec.shape[0])  # always 1 after pooling; unused
            # print(i, 'doc shape', doc_vec.shape)
            docs.append(doc_vec)

        # [batch size, bert embedding dim]
        docs = torch.cat(docs, 0)

        # the prompt is a single document shared by the batch, so it is
        # indexed without a batch dimension
        prompt = []
        for j in range(prompt_sent_counts):
            length = prompt_sent_lens[0][j]
            sent = prompt_hidden_states[j, :length, :]
            prompt.append(sent)

        prompt_vec = torch.cat(prompt, dim=0).unsqueeze(0)
        prompt_vec = self.positional_encoding.forward(prompt_vec)
        # mean [1, bert embedding dim]
        prompt_vec = torch.mean(prompt_vec, dim=1)
        # prompt_vec = self.linear_prompt(prompt_vec)

        doc_feature = docs
        prompt_feature = prompt_vec.expand_as(doc_feature)

        feature = torch.cat([doc_feature, prompt_feature], dim=-1)
        log_probs = torch.log_softmax(torch.tanh(self.linear_layer(feature)),
                                      dim=-1)

        # log_probs = self.classifier(docs)
        if label is not None:
            loss = self.criterion(
                input=log_probs.contiguous().view(-1, log_probs.shape[-1]),
                target=label.contiguous().view(-1))
        else:
            loss = None

        prediction = torch.max(log_probs, dim=1)[1]
        return {'loss': loss, 'prediction': prediction}
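

# Hedged usage sketch (not part of the original source): a smoke test for
# BertSimpleClassifier with dummy inputs. The checkpoint name
# 'bert-base-uncased' and every size below are illustrative assumptions.
def _demo_bert_simple_classifier():
    model = BertSimpleClassifier('bert-base-uncased', num_class=4)
    batch_size, max_sents, max_len = 2, 3, 10
    inputs = torch.randint(0, 1000, (batch_size, max_sents, max_len))
    mask = torch.ones(batch_size, max_sents, max_len, dtype=torch.long)
    sent_counts = torch.tensor([3, 2])
    sent_lens = torch.full((batch_size, max_sents), max_len, dtype=torch.long)
    # one prompt shared by the whole batch (see the forward() docstring)
    prompt_inputs = torch.randint(0, 1000, (1, 2, max_len))
    prompt_mask = torch.ones(1, 2, max_len, dtype=torch.long)
    out = model(inputs, mask, sent_counts, sent_lens,
                prompt_inputs, prompt_mask,
                prompt_sent_counts=2,
                prompt_sent_lens=[[max_len, max_len]],
                label=torch.tensor([0, 1]))
    print(out['prediction'], out['loss'])

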
class MixBertRecurrentAttentionRegressor(nn.Module):
    def __init__(self, bert_pretrained_weights):

        super().__init__()
        self.bert = BertModel.from_pretrained(bert_pretrained_weights)
        self.positional_encoding = PositionalEncoding(input_dim=768)

        # doc attention vector (768) + projected manual features (5) +
        # segment-RNN state (300; presumably 2 * 150 for a bidirectional
        # encoder)
        self.linear_layer = nn.Linear(768 + 5 + 300, 1)
        self.dropout_layer = nn.Dropout(0.6)
        self.criterion = nn.MSELoss(reduction='sum')

        self.manual_feature_layer = nn.Linear(27, 5)

        self.prompt_global_attention = GlobalAttention(hid_dim=768,
                                                       key_size=768)
        self.prompt_doc_attention = BahdanauAttention(hid_dim=768,
                                                      key_size=768,
                                                      query_size=768)

        self.segment_encoder = RNNEncoder(embedding_dim=768,
                                          hid_dim=150,
                                          num_layers=1,
                                          dropout_rate=0.5)

        nn.init.uniform_(self.linear_layer.weight.data, -0.1, 0.1)
        nn.init.zeros_(self.linear_layer.bias.data)

    def forward(self,
                inputs,
                mask,
                sent_counts,
                sent_lens,
                prompt_inputs,
                prompt_mask,
                prompt_sent_counts,
                prompt_sent_lens,
                min_score,
                max_score,
                manual_feature,
                label=None):
        """

        :param manual_feature: [batch size]
        :param max_score: [batch size]
        :param min_score: [batch size]
        :param prompt_sent_lens: [batch size, max sent count]
        :param prompt_sent_counts: [batch size]
        :param prompt_inputs:   [batch size, max sent count, max sent len]
        :param prompt_mask: [batch size, max sent count, max sent len]
        :param inputs:  [batch size, max sent count, max sent len]
        :param mask:    [batch size, max sent count, max sent len]
        :param sent_counts: [batch size]
        :param sent_lens: [batch size, max sent count]
        :param label: [batch size]
        :return:
        """
        batch_size = inputs.shape[0]
        max_sent_count = inputs.shape[1]
        max_sent_length = inputs.shape[2]

        max_prompt_sent_count = prompt_inputs.shape[1]
        max_prompt_sent_length = prompt_inputs.shape[2]

        inputs = inputs.view(-1, inputs.shape[-1])
        mask = mask.view(-1, mask.shape[-1])

        # [batch size * max sent count, max sent len, hid size]
        last_hidden_states = self.bert(input_ids=inputs,
                                       attention_mask=mask)[0]
        last_hidden_states = last_hidden_states.view(batch_size,
                                                     max_sent_count,
                                                     max_sent_length, -1)
        last_hidden_states = self.dropout_layer(last_hidden_states)

        prompt_inputs = prompt_inputs.view(-1, prompt_inputs.shape[-1])
        prompt_mask = prompt_mask.view(-1, prompt_mask.shape[-1])
        prompt_hidden_states = self.bert(input_ids=prompt_inputs,
                                         attention_mask=prompt_mask)[0]
        prompt_hidden_states = prompt_hidden_states.view(
            batch_size, max_prompt_sent_count, max_prompt_sent_length, -1)
        prompt_hidden_states = self.dropout_layer(prompt_hidden_states)

        docs = []
        lens = []
        doc_segments = []
        for i in range(0, batch_size):
            doc = []
            doc_segment = []
            sent_count = sent_counts[i]
            sent_len = sent_lens[i]

            for j in range(sent_count):
                length = sent_len[j]
                cur_sent = last_hidden_states[i, j, :length, :]
                mean_cur_sent = torch.mean(cur_sent, dim=0)
                # print('cur sent shape', cur_sent.shape)
                doc.append(cur_sent)
                doc_segment.append(mean_cur_sent.unsqueeze(0))

            # [1, len, hid size]
            doc_vec = torch.cat(doc, dim=0).unsqueeze(0)
            doc_vec = self.positional_encoding.forward(doc_vec)

            lens.append(doc_vec.shape[1])
            # print(i, 'doc shape', doc_vec.shape)
            docs.append(doc_vec)
            doc_segments.append(doc_segment)

        # zero-pad each document along the token dimension to the batch max
        batch_max_len = max(lens)
        for i, doc in enumerate(docs):
            if doc.shape[1] < batch_max_len:
                pd = (0, 0, 0, batch_max_len - doc.shape[1])
                m = nn.ConstantPad2d(pd, 0)
                doc = m(doc)

            docs[i] = doc

        # [batch size, batch max len, hid size]
        docs = torch.cat(docs, 0)
        docs_mask = get_mask_from_sequence_lengths(
            torch.tensor(lens), max_length=batch_max_len).to(docs.device)
        # print('lens ', lens)
        # print('docs shape', docs.shape)

        prompt_docs = []
        prompt_lens = []
        for i in range(0, batch_size):
            prompt_doc = []
            prompt_sent_count = prompt_sent_counts[i]
            prompt_sent_len = prompt_sent_lens[i]

            for j in range(prompt_sent_count):
                length = prompt_sent_len[j]
                cur_sent = prompt_hidden_states[i, j, :length, :]
                prompt_doc.append(cur_sent)

            prompt_doc_vec = torch.cat(prompt_doc, dim=0).unsqueeze(0)
            prompt_doc_vec = self.positional_encoding.forward(prompt_doc_vec)

            prompt_lens.append(prompt_doc_vec.shape[1])
            prompt_docs.append(prompt_doc_vec)

        prompt_batch_max_len = max(prompt_lens)
        for i, doc in enumerate(prompt_docs):
            if doc.shape[1] < prompt_batch_max_len:
                pd = (0, 0, 0, prompt_batch_max_len - doc.shape[1])
                m = nn.ConstantPad2d(pd, 0)
                doc = m(doc)

            prompt_docs[i] = doc

        prompt_docs = torch.cat(prompt_docs, 0)
        prompt_attention_mask = get_mask_from_sequence_lengths(
            torch.tensor(prompt_lens),
            max_length=prompt_batch_max_len).to(docs.device)
        # [batch size, max seq len]
        prompt_vec_weights = self.prompt_global_attention(
            prompt_docs, prompt_attention_mask)

        # [batch size, bert hidden size]
        prompt_vec = torch.bmm(prompt_vec_weights.unsqueeze(1),
                               prompt_docs).squeeze(1)
        # print('prompt len', prompt_len)

        doc_weights = self.prompt_doc_attention(query=prompt_vec,
                                                key=docs,
                                                mask=docs_mask)
        doc_vec = torch.bmm(doc_weights.unsqueeze(1), docs).squeeze(1)
        doc_feature = self.dropout_layer(torch.tanh(doc_vec))
        manual_feature = torch.tanh(
            self.manual_feature_layer(self.dropout_layer(manual_feature)))

        # rnn segments encoder
        sorted_index = sorted(range(len(sent_counts)),
                              key=lambda i: sent_counts[i],
                              reverse=True)
        # zero-pad each doc's sentence-mean sequence to max sent count
        max_count = max_sent_count
        for idx, doc in enumerate(doc_segments):
            for i in range(max_count - len(doc)):
                doc.append(torch.zeros_like(doc[0]))
            doc_segments[idx] = torch.cat(doc, dim=0).unsqueeze(0)
        doc_segments = torch.cat(doc_segments, dim=0)

        sorted_doc_segments = doc_segments[sorted_index]
        sorted_batch_counts = sent_counts[sorted_index]
        final_hidden_states = self.segment_encoder(
            sorted_doc_segments, sorted_batch_counts)['final_hidden_states']
        # scatter encoder outputs back into the original batch order; writing
        # into a fresh tensor avoids aliased in-place indexing on the tensor
        # being read
        restored_hidden_states = torch.empty_like(final_hidden_states)
        restored_hidden_states[sorted_index] = final_hidden_states
        final_hidden_states = torch.tanh(restored_hidden_states)
        final_hidden_states = self.dropout_layer(final_hidden_states)

        # feature = self.dropout_layer(torch.tanh(doc_vec))
        # prompt_feature = self.dropout_layer(torch.tanh(prompt_vec.expand_as(doc_feature)))
        feature = torch.cat([doc_feature, manual_feature, final_hidden_states],
                            dim=-1)

        grade = self.linear_layer(feature)
        if label is not None:
            # print('label ', label)
            # print('min score ', min_score)
            # print('max score ', max_score)
            # grade = grade * (max_score - min_score) + min_score
            label = (label.type_as(grade) - min_score.type_as(grade)) / (
                max_score.type_as(grade) - min_score.type_as(grade))
            loss = self.criterion(
                input=grade.contiguous().view(-1),
                target=label.type_as(grade).contiguous().view(-1))
        else:
            loss = None

        # grade is [batch size, 1]; flatten it so the rescale broadcasts
        # element-wise with the [batch size] score bounds
        score_range = max_score.type_as(grade) - min_score.type_as(grade)
        prediction = grade.view(-1) * score_range + min_score.type_as(grade)
        return {'loss': loss, 'prediction': prediction}
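

# Hedged usage sketch (not part of the original source) for
# MixBertRecurrentAttentionRegressor. Unlike the classifiers above, the
# prompt here carries a batch dimension. All names and sizes are assumptions.
def _demo_mix_bert_regressor():
    model = MixBertRecurrentAttentionRegressor('bert-base-uncased')
    bs, max_sents, max_len = 2, 3, 10
    inputs = torch.randint(0, 1000, (bs, max_sents, max_len))
    mask = torch.ones(bs, max_sents, max_len, dtype=torch.long)
    sent_counts = torch.tensor([3, 2])
    sent_lens = torch.full((bs, max_sents), max_len, dtype=torch.long)
    prompt_inputs = torch.randint(0, 1000, (bs, 2, max_len))
    prompt_mask = torch.ones(bs, 2, max_len, dtype=torch.long)
    out = model(inputs, mask, sent_counts, sent_lens,
                prompt_inputs, prompt_mask,
                prompt_sent_counts=torch.tensor([2, 2]),
                prompt_sent_lens=torch.full((bs, 2), max_len,
                                            dtype=torch.long),
                min_score=torch.tensor([0.0, 0.0]),
                max_score=torch.tensor([10.0, 10.0]),
                manual_feature=torch.randn(bs, 27),
                label=torch.tensor([7.0, 5.0]))
    print(out['prediction'], out['loss'])

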
class BertGlobalAttentionClassifier(nn.Module):
    def __init__(self, bert_pretrained_weights, num_class):

        super().__init__()
        self.bert = BertModel.from_pretrained(bert_pretrained_weights)

        self.positional_encoding = PositionalEncoding(input_dim=768)

        self.linear_layer = nn.Linear(768 * 2 + 5, num_class)
        self.manual_feature_layer = nn.Linear(27, 5)
        self.dropout_layer = nn.Dropout(0.5)
        self.criterion = nn.NLLLoss(reduction='mean')

        self.prompt_global_attention = GlobalAttention(hid_dim=768,
                                                       key_size=768)
        self.doc_global_attention = GlobalAttention(hid_dim=768, key_size=768)

        nn.init.uniform_(self.linear_layer.weight.data, -0.1, 0.1)
        nn.init.zeros_(self.linear_layer.bias.data)

    def forward(self,
                inputs,
                mask,
                sent_counts,
                sent_lens,
                prompt_inputs,
                prompt_mask,
                prompt_sent_counts,
                prompt_sent_lens,
                manual_feature,
                label=None):
        """

        :param prompt_sent_lens:
        :param prompt_sent_counts:
        :param prompt_inputs:
        :param prompt_mask:
        :param inputs:  [batch size, max sent count, max sent len]
        :param mask:    [batch size, max sent count, max sent len]
        :param sent_counts: [batch size]
        :param sent_lens: [batch size, max sent count]
        :param label: [batch size]
        :return:
        """
        batch_size = inputs.shape[0]
        max_sent_count = inputs.shape[1]
        max_sent_length = inputs.shape[2]

        inputs = inputs.view(-1, inputs.shape[-1])
        mask = mask.view(-1, mask.shape[-1])

        # [batch size * max sent count, max sent len, hid size]
        last_hidden_states = self.bert(input_ids=inputs,
                                       attention_mask=mask)[0]
        last_hidden_states = last_hidden_states.view(batch_size,
                                                     max_sent_count,
                                                     max_sent_length, -1)

        prompt_inputs = prompt_inputs.view(-1, prompt_inputs.shape[-1])
        prompt_mask = prompt_mask.view(-1, prompt_mask.shape[-1])
        prompt_hidden_states = self.bert(input_ids=prompt_inputs,
                                         attention_mask=prompt_mask)[0]

        docs = []
        lens = []
        for i in range(0, batch_size):
            doc = []
            sent_count = sent_counts[i]
            sent_len = sent_lens[i]

            for j in range(sent_count):
                length = sent_len[j]
                cur_sent = last_hidden_states[i, j, :length, :]
                # print('cur sent shape', cur_sent.shape)
                doc.append(cur_sent)

            # concatenate the doc's sentences and add positional encoding
            # (no pooling here; global attention aggregates later)
            doc_vec = torch.cat(doc, dim=0).unsqueeze(0)
            doc_vec = self.positional_encoding.forward(doc_vec)

            lens.append(doc_vec.shape[1])
            # print(i, 'doc shape', doc_vec.shape)
            docs.append(doc_vec)

        # zero-pad each document along the token dimension to the batch max
        batch_max_len = max(lens)
        for i, doc in enumerate(docs):
            if doc.shape[1] < batch_max_len:
                pd = (0, 0, 0, batch_max_len - doc.shape[1])
                m = nn.ConstantPad2d(pd, 0)
                doc = m(doc)

            docs[i] = doc

        # [batch size, batch max len, hid size]
        docs = torch.cat(docs, 0)
        docs_mask = get_mask_from_sequence_lengths(
            torch.tensor(lens), max_length=batch_max_len).to(docs.device)

        prompt = []
        for j in range(prompt_sent_counts):
            length = prompt_sent_lens[0][j]
            sent = prompt_hidden_states[j, :length, :]
            prompt.append(sent)

        prompt_vec = torch.cat(prompt, dim=0).unsqueeze(0)
        prompt_vec = self.positional_encoding.forward(prompt_vec)
        prompt_len = prompt_vec.shape[1]
        prompt_attention_mask = get_mask_from_sequence_lengths(
            torch.tensor([prompt_len]),
            max_length=prompt_len).to(prompt_vec.device)
        # [1, seq len]
        prompt_vec_weights = self.prompt_global_attention(
            prompt_vec, prompt_attention_mask)
        # [1, bert hidden size]
        prompt_vec = torch.bmm(prompt_vec_weights.unsqueeze(1),
                               prompt_vec).squeeze(1)

        doc_weights = self.doc_global_attention(docs, docs_mask)
        doc_vec = torch.bmm(doc_weights.unsqueeze(1), docs).squeeze(1)

        doc_feature = self.dropout_layer(torch.tanh(doc_vec))
        prompt_feature = self.dropout_layer(
            torch.tanh(prompt_vec.expand_as(doc_feature)))
        # the layer widths (linear_layer takes 768 * 2 + 5; the manual
        # feature projection emits 5) imply the projected manual features
        # belong in this concatenation, mirroring
        # MixBertRecurrentAttentionRegressor; without them the linear
        # layer's input width would not match
        manual_feature = torch.tanh(
            self.manual_feature_layer(self.dropout_layer(manual_feature)))
        feature = torch.cat([doc_feature, prompt_feature, manual_feature],
                            dim=-1)

        log_probs = torch.log_softmax(self.linear_layer(feature), dim=-1)

        # log_probs = self.classifier(docs)
        if label is not None:
            loss = self.criterion(
                input=log_probs.contiguous().view(-1, log_probs.shape[-1]),
                target=label.contiguous().view(-1))
        else:
            loss = None

        prediction = torch.max(log_probs, dim=1)[1]
        return {'loss': loss, 'prediction': prediction}
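

# Hedged usage sketch (not part of the original source) for
# BertGlobalAttentionClassifier; as with BertSimpleClassifier, the prompt is
# a single shared document. All sizes below are illustrative assumptions.
def _demo_bert_global_attention_classifier():
    model = BertGlobalAttentionClassifier('bert-base-uncased', num_class=4)
    bs, max_sents, max_len = 2, 3, 10
    inputs = torch.randint(0, 1000, (bs, max_sents, max_len))
    mask = torch.ones(bs, max_sents, max_len, dtype=torch.long)
    out = model(inputs, mask,
                sent_counts=torch.tensor([3, 2]),
                sent_lens=torch.full((bs, max_sents), max_len,
                                     dtype=torch.long),
                prompt_inputs=torch.randint(0, 1000, (1, 2, max_len)),
                prompt_mask=torch.ones(1, 2, max_len, dtype=torch.long),
                prompt_sent_counts=2,
                prompt_sent_lens=[[max_len, max_len]],
                manual_feature=torch.randn(bs, 27),
                label=torch.tensor([0, 1]))
    print(out['prediction'], out['loss'])

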
class BertClassifier(nn.Module):
    def __init__(self, bert_pretrained_weights, num_class, kernel_size,
                 kernel_nums):

        # NOTE: kernel_size and kernel_nums feed only the commented-out CNN
        # extractors below; the active code path ignores them.
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_pretrained_weights)

        self.positional_encoding = PositionalEncoding(input_dim=768)
        # self.classifier = CNNClassifier(num_class=num_class,
        #                                 input_dim=768,
        #                                 kernel_nums=kernel_nums,
        #                                 kernel_sizes=kernel_size,
        #                                 max_kernel_size=kernel_size[-1])

        # self.essay_feature_extracter = CNNFeatureExtrater(
        #     input_dim=768,
        #     output_dim=300,
        #     kernel_nums=kernel_nums,
        #     kernel_sizes=kernel_size,
        #     max_kernel_size=kernel_size[-1]
        # )
        # self.prompt_feature_extracter = CNNFeatureExtrater(
        #     input_dim=768,
        #     output_dim=300,
        #     kernel_sizes=[2, 4, 8, 16, 32, 64, 128, 256],
        #     kernel_nums=[64, 64, 64, 64, 64, 64, 64, 64],
        #     max_kernel_size=kernel_size[-1]
        # )
        self.linear_layer = nn.Linear(768 * 2, num_class)

        self.dropout_layer = nn.Dropout(0.5)
        self.criterion = nn.NLLLoss(reduction='mean')

    def forward(self,
                inputs,
                mask,
                sent_counts,
                sent_lens,
                prompt_inputs,
                prompt_mask,
                prompt_sent_counts,
                prompt_sent_lens,
                label=None):
        """

        :param prompt_sent_lens:
        :param prompt_sent_counts:
        :param prompt_inputs:
        :param prompt_mask:
        :param inputs:  [batch size, max sent count, max sent len]
        :param mask:    [batch size, max sent count, max sent len]
        :param sent_counts: [batch size]
        :param sent_lens: [batch size, max sent count]
        :param label: [batch size]
        :return:
        """
        batch_size = inputs.shape[0]
        max_sent_count = inputs.shape[1]
        max_sent_length = inputs.shape[2]

        inputs = inputs.view(-1, inputs.shape[-1])
        mask = mask.view(-1, mask.shape[-1])

        # [batch size * max sent count, max sent len, hid size]
        last_hidden_states = self.bert(input_ids=inputs,
                                       attention_mask=mask)[0]
        last_hidden_states = last_hidden_states.view(batch_size,
                                                     max_sent_count,
                                                     max_sent_length, -1)

        prompt_inputs = prompt_inputs.view(-1, prompt_inputs.shape[-1])
        prompt_mask = prompt_mask.view(-1, prompt_mask.shape[-1])
        prompt_hidden_states = self.bert(input_ids=prompt_inputs,
                                         attention_mask=prompt_mask)[0]

        docs = []
        lens = []
        for i in range(0, batch_size):
            doc = []
            sent_count = sent_counts[i]
            sent_len = sent_lens[i]

            for j in range(sent_count):
                length = sent_len[j]
                cur_sent = last_hidden_states[i, j, :length, :]
                # print('cur sent shape', cur_sent.shape)
                doc.append(cur_sent)

            # concatenate sentences, add positional encoding, mean-pool over
            # tokens -> [1, hid size]
            doc_vec = torch.cat(doc, dim=0).unsqueeze(0)
            doc_vec = self.positional_encoding.forward(doc_vec)
            doc_vec = torch.mean(doc_vec, dim=1)

            lens.append(doc_vec.shape[0])  # always 1 after pooling; unused
            # print(i, 'doc shape', doc_vec.shape)
            docs.append(doc_vec)

        # batch_max_len = max(lens)
        # for i, doc in enumerate(docs):
        #     if doc.shape[0] < batch_max_len:
        #         pd = (0, 0, 0, batch_max_len - doc.shape[0])
        #         m = nn.ConstantPad2d(pd, 0)
        #         doc = m(doc)
        #
        #     docs[i] = doc.unsqueeze(0)

        docs = torch.cat(docs, 0)
        # print(docs.shape)
        # docs = self.positional_encoding.forward(docs)
        # [batch size, num_class]

        prompt = []
        for j in range(prompt_sent_counts):
            length = prompt_sent_lens[0][j]
            sent = prompt_hidden_states[j, :length, :]
            prompt.append(sent)

        prompt_vec = torch.cat(prompt, dim=0).unsqueeze(0)
        prompt_vec = self.positional_encoding.forward(prompt_vec)
        prompt_vec = torch.mean(prompt_vec, dim=1)

        # [batch size, feature size]
        # doc_feature = self.essay_feature_extracter(docs)
        # prompt_feature = self.prompt_feature_extracter(prompt_vec)
        # prompt_feature = prompt_feature.expand_as(doc_feature)

        doc_feature = self.dropout_layer(torch.tanh(docs))
        prompt_feature = self.dropout_layer(
            torch.tanh(prompt_vec.expand_as(doc_feature)))

        feature = torch.cat([doc_feature, prompt_feature], dim=-1)
        log_probs = torch.log_softmax(self.linear_layer(feature), dim=-1)

        # log_probs = self.classifier(docs)
        if label is not None:
            loss = self.criterion(
                input=log_probs.contiguous().view(-1, log_probs.shape[-1]),
                target=label.contiguous().view(-1))
        else:
            loss = None

        prediction = torch.max(log_probs, dim=1)[1]
        return {'loss': loss, 'prediction': prediction}
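

# Hedged usage sketch (not part of the original source) for BertClassifier.
# kernel_size / kernel_nums only feed the commented-out CNN heads, so the
# values passed here are placeholders; the other sizes are assumptions too.
def _demo_bert_classifier():
    model = BertClassifier('bert-base-uncased', num_class=4,
                           kernel_size=[2, 3, 4], kernel_nums=[64, 64, 64])
    bs, max_sents, max_len = 2, 3, 10
    inputs = torch.randint(0, 1000, (bs, max_sents, max_len))
    mask = torch.ones(bs, max_sents, max_len, dtype=torch.long)
    out = model(inputs, mask,
                sent_counts=torch.tensor([3, 2]),
                sent_lens=torch.full((bs, max_sents), max_len,
                                     dtype=torch.long),
                prompt_inputs=torch.randint(0, 1000, (1, 2, max_len)),
                prompt_mask=torch.ones(1, 2, max_len, dtype=torch.long),
                prompt_sent_counts=2,
                prompt_sent_lens=[[max_len, max_len]],
                label=torch.tensor([0, 1]))
    print(out['prediction'], out['loss'])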