Example #1
    def __init__(self, config, hidden_size, output_modules=1):
        super().__init__()
        self.linear_re = nn.Linear(hidden_size * 2, hidden_size)
        self.use_distance = True
        self.use_cross_attention = config.cross_encoder

        if self.use_distance:
            self.dis_embed = nn.Embedding(20, config.dis_size, padding_idx=10)
            dis_size = config.dis_size
        else:
            dis_size = 0

        if self.use_cross_attention:
            self.attention = Attention(hidden_size, hidden_size, hidden_size)

        self.num_output_modules = output_modules
        if output_modules == 1:
            self.bili = PredictionBiLinear(hidden_size + dis_size,
                                           hidden_size + dis_size,
                                           config.relation_num)
        else:
            bili_list = [
                PredictionBiLinear(hidden_size + dis_size,
                                   hidden_size + dis_size, config.relation_num)
                for _ in range(output_modules)
            ]
            self.bili_list = nn.ModuleList(bili_list)
Example #2
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_output_module = self.config.num_output_module
        self.use_entity_type = config.use_ner_emb
        self.mu_activation = config.mu_activation_option
        bert_hidden_size = 768
        hidden_size = config.hidden_size
        if self.num_output_module > 1:
            self.twin_init = config.twin_init
        self.use_cross_attention = config.cross_encoder
        entity_vector_size = config.entity_type_size if self.use_entity_type else 0

        if self.use_entity_type:
            self.ner_emb = nn.Embedding(7,
                                        config.entity_type_size,
                                        padding_idx=0)

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        if not self.config.train_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        self.linear = nn.Linear(bert_hidden_size, hidden_size)
        context_hidden_size = hidden_size + entity_vector_size

        self.use_distance = True

        if self.use_distance:
            self.dis_embed = nn.Embedding(20, config.dis_size, padding_idx=10)
            vect_size = context_hidden_size + config.dis_size
        else:
            vect_size = context_hidden_size

        # num_head = 4
        # self.attention = nn.MultiheadAttention(vect_size, num_head)
        if self.use_cross_attention:
            self.attention = Attention(context_hidden_size,
                                       context_hidden_size,
                                       context_hidden_size)

        if self.num_output_module == 1:
            self.bili = PredictionBiLinear(vect_size, vect_size,
                                           config.relation_num)
        else:
            # self.bili_multi = PredictionBiLinearMulti(self.num_output_module, vect_size, vect_size, config.relation_num)
            bili_list = [
                PredictionBiLinear(vect_size, vect_size, config.relation_num)
                for _ in range(self.num_output_module)
            ]
            if self.twin_init:
                bili_list[1].load_state_dict(bili_list[0].state_dict())
            self.bili_list = nn.ModuleList(bili_list)
Example #3
    def __init__(self,
                 config,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 output_modules=1):
        super().__init__()
        self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding = in_channels, out_channels, kernel_size, stride, padding
        self.use_cross_attention = config.cross_encoder
        self.cnn_1 = nn.Conv1d(self.in_channels, self.out_channels,
                               self.kernel_size, self.stride, self.padding)
        self.cnn_2 = nn.Conv1d(self.out_channels, self.out_channels,
                               self.kernel_size, self.stride, self.padding)
        self.cnn_3 = nn.Conv1d(self.out_channels, self.out_channels,
                               self.kernel_size, self.stride, self.padding)
        self.max_pooling = nn.MaxPool1d(self.kernel_size,
                                        stride=self.stride,
                                        padding=self.padding)
        self.relu = nn.ReLU()

        self.dropout = nn.Dropout(config.cnn_drop_prob)
        self.dis_embed = nn.Embedding(20, config.dis_size, padding_idx=10)

        if self.use_cross_attention:
            hidden_size = self.out_channels
            self.attention = Attention(hidden_size, hidden_size, hidden_size)

        self.num_output_modules = output_modules
        if output_modules == 1:
            self.bili = PredictionBiLinear(self.out_channels + config.dis_size,
                                           self.out_channels + config.dis_size,
                                           config.relation_num)
        else:
            bili_list = [
                PredictionBiLinear(self.out_channels + config.dis_size,
                                   self.out_channels + config.dis_size,
                                   config.relation_num)
                for _ in range(output_modules)
            ]
            self.bili_list = nn.ModuleList(bili_list)
Example #4
class CommonNetBiLSTM(nn.Module):
    def __init__(self, config, hidden_size, output_modules=1):
        super().__init__()
        self.linear_re = nn.Linear(hidden_size * 2, hidden_size)
        self.use_distance = True
        self.use_cross_attention = config.cross_encoder

        if self.use_distance:
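            # 20 relative-distance buckets; index 10 presumably encodes the
            # zero/padding bucket (padding_idx keeps its embedding at zero).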
            self.dis_embed = nn.Embedding(20, config.dis_size, padding_idx=10)
            dis_size = config.dis_size
        else:
            dis_size = 0

        if self.use_cross_attention:
            self.attention = Attention(hidden_size, hidden_size, hidden_size)

        self.num_output_modules = output_modules
        if output_modules == 1:
            self.bili = PredictionBiLinear(hidden_size + dis_size,
                                           hidden_size + dis_size,
                                           config.relation_num)
        else:
            bili_list = [
                PredictionBiLinear(hidden_size + dis_size,
                                   hidden_size + dis_size, config.relation_num)
                for _ in range(output_modules)
            ]
            self.bili_list = nn.ModuleList(bili_list)

    def fix_prediction_bias(self, bias):
        assert self.num_output_modules == 1
        self.bili.fix_bias(bias)

    def add_prediction_bias(self, bias):
        assert self.num_output_modules == 1
        self.bili.add_bias(bias)

    def forward(self, sent, h_mapping, t_mapping, dis_h_2_t, dis_t_2_h):
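        # Assumed shapes (inferred from the layers above, not checked here):
        #   sent:      [batch, seq_len, 2 * hidden]  (BiLSTM output, hence linear_re)
        #   h_mapping: [batch, num_pairs, seq_len]   head-mention pooling weights
        #   t_mapping: [batch, num_pairs, seq_len]   tail-mention pooling weights
        #   dis_h_2_t, dis_t_2_h: bucketed relative-distance indices in [0, 19]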
        context_output = torch.relu(self.linear_re(sent))
        start_re_output = torch.matmul(h_mapping, context_output)
        end_re_output = torch.matmul(t_mapping, context_output)

        if self.use_cross_attention:
            # Query with the original pooled representations so the second call
            # does not consume the freshly updated end_re_output (this mirrors
            # the head/tail handling in BERT_RE below).
            tail, _ = self.attention(start_re_output, context_output, t_mapping)
            head, _ = self.attention(end_re_output, context_output, h_mapping)
            start_re_output, end_re_output = head, tail

        if self.use_distance:
            s_rep = torch.cat(
                [start_re_output, self.dis_embed(dis_h_2_t)], dim=-1)
            t_rep = torch.cat(
                [end_re_output, self.dis_embed(dis_t_2_h)], dim=-1)
        else:
            s_rep = start_re_output
            t_rep = end_re_output

        if self.num_output_modules == 1:
            predict_re = self.bili(s_rep, t_rep)
            return predict_re

        output = [
            self.bili_list[i](s_rep, t_rep)
            for i in range(self.num_output_modules)
        ]
        return output
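The torch.matmul(h_mapping, context_output) step that every forward pass here relies on is batched mention pooling: each row of h_mapping carries weights over token positions, so the product averages (or sums) the encoder states of one head mention per candidate pair. A minimal self-contained sketch; the sizes and weights below are illustrative assumptions, not values from the repository:

import torch

batch, num_pairs, seq_len, hidden = 2, 3, 5, 4
context_output = torch.randn(batch, seq_len, hidden)       # one encoder state per token

# h_mapping[b, p, t]: weight of token t for pair p's head mention; rows that
# sum to 1 make the matmul below an average over the mention's tokens.
h_mapping = torch.zeros(batch, num_pairs, seq_len)
h_mapping[:, 0, 1:3] = 0.5          # pair 0: head mention spans tokens 1-2
h_mapping[:, 1, 0] = 1.0            # pair 1: head mention is token 0
h_mapping[:, 2, 2:5] = 1.0 / 3      # pair 2: head mention spans tokens 2-4

start_re_output = torch.matmul(h_mapping, context_output)  # [batch, num_pairs, hidden]
print(start_re_output.shape)                               # torch.Size([2, 3, 4])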
Example #5
class PredictionCNN(nn.Module):
    def __init__(self,
                 config,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 output_modules=1):
        super().__init__()
        self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding = in_channels, out_channels, kernel_size, stride, padding
        self.use_cross_attention = config.cross_encoder
        self.cnn_1 = nn.Conv1d(self.in_channels, self.out_channels,
                               self.kernel_size, self.stride, self.padding)
        self.cnn_2 = nn.Conv1d(self.out_channels, self.out_channels,
                               self.kernel_size, self.stride, self.padding)
        self.cnn_3 = nn.Conv1d(self.out_channels, self.out_channels,
                               self.kernel_size, self.stride, self.padding)
        self.max_pooling = nn.MaxPool1d(self.kernel_size,
                                        stride=self.stride,
                                        padding=self.padding)
        self.relu = nn.ReLU()

        self.dropout = nn.Dropout(config.cnn_drop_prob)
        self.dis_embed = nn.Embedding(20, config.dis_size, padding_idx=10)

        if self.use_cross_attention:
            hidden_size = self.out_channels
            self.attention = Attention(hidden_size, hidden_size, hidden_size)

        self.num_output_modules = output_modules
        if output_modules == 1:
            self.bili = PredictionBiLinear(self.out_channels + config.dis_size,
                                           self.out_channels + config.dis_size,
                                           config.relation_num)
        else:
            bili_list = [
                PredictionBiLinear(self.out_channels + config.dis_size,
                                   self.out_channels + config.dis_size,
                                   config.relation_num)
                for _ in range(output_modules)
            ]
            self.bili_list = nn.ModuleList(bili_list)

    def fix_prediction_bias(self, bias):
        assert self.num_output_modules == 1
        self.bili.fix_bias(bias)

    def add_prediction_bias(self, bias):
        assert self.num_output_modules == 1
        self.bili.add_bias(bias)

    def forward(self, sent, h_mapping, t_mapping, dis_h_2_t, dis_t_2_h):
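        # Assumed layout: sent is [batch, in_channels, seq_len] (channels-first,
        # as nn.Conv1d expects); the permute below brings the pooled features
        # back to [batch, pooled_len, out_channels] before mention pooling.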
        x = self.cnn_1(sent)
        x = self.max_pooling(x)
        x = self.relu(x)
        x = self.dropout(x)

        x = self.cnn_2(x)
        x = self.max_pooling(x)
        x = self.relu(x)
        x = self.dropout(x)

        x = self.cnn_3(x)
        x = self.max_pooling(x)
        x = self.relu(x)
        x = self.dropout(x)

        context_output = x.permute(0, 2, 1)
        start_re_output = torch.matmul(h_mapping, context_output)
        end_re_output = torch.matmul(t_mapping, context_output)

        if self.use_cross_attention:
            # Query with the original pooled representations so the second call
            # does not consume the freshly updated end_re_output (as in BERT_RE below).
            tail, _ = self.attention(start_re_output, context_output, t_mapping)
            head, _ = self.attention(end_re_output, context_output, h_mapping)
            start_re_output, end_re_output = head, tail

        s_rep = torch.cat([start_re_output, self.dis_embed(dis_h_2_t)], dim=-1)
        t_rep = torch.cat([end_re_output, self.dis_embed(dis_t_2_h)], dim=-1)

        if self.num_output_modules == 1:
            predict_re = self.bili(s_rep, t_rep)
            return predict_re

        output = [
            self.bili_list[i](s_rep, t_rep)
            for i in range(self.num_output_modules)
        ]
        return output
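All three Conv1d/MaxPool1d stages in PredictionCNN share kernel_size, stride and padding, so the sequence length only survives unchanged when those hyperparameters are length-preserving, and h_mapping / t_mapping presumably have to be built for the pooled length. A quick shape check with made-up hyperparameters:

import torch
import torch.nn as nn

in_ch, out_ch, k, s, p = 16, 32, 3, 1, 1   # illustrative values only
stack = nn.Sequential(
    nn.Conv1d(in_ch, out_ch, k, s, p), nn.MaxPool1d(k, stride=s, padding=p),
    nn.Conv1d(out_ch, out_ch, k, s, p), nn.MaxPool1d(k, stride=s, padding=p),
    nn.Conv1d(out_ch, out_ch, k, s, p), nn.MaxPool1d(k, stride=s, padding=p),
)
x = torch.randn(2, in_ch, 50)              # [batch, channels, tokens]
print(stack(x).shape)                      # torch.Size([2, 32, 50]) with these settings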
Example #6
class BERT_RE(nn.Module):
    name = "BERT"

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_output_module = self.config.num_output_module
        self.use_entity_type = config.use_ner_emb
        self.mu_activation = config.mu_activation_option
        bert_hidden_size = 768
        hidden_size = config.hidden_size
        if self.num_output_module > 1:
            self.twin_init = config.twin_init
        self.use_cross_attention = config.cross_encoder
        entity_vector_size = config.entity_type_size if self.use_entity_type else 0

        if self.use_entity_type:
            self.ner_emb = nn.Embedding(7,
                                        config.entity_type_size,
                                        padding_idx=0)

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        if not self.config.train_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        self.linear = nn.Linear(bert_hidden_size, hidden_size)
        context_hidden_size = hidden_size + entity_vector_size

        self.use_distance = True

        if self.use_distance:
            self.dis_embed = nn.Embedding(20, config.dis_size, padding_idx=10)
            vect_size = context_hidden_size + config.dis_size
        else:
            vect_size = context_hidden_size

        # num_head = 4
        # self.attention = nn.MultiheadAttention(vect_size, num_head)
        if self.use_cross_attention:
            self.attention = Attention(context_hidden_size,
                                       context_hidden_size,
                                       context_hidden_size)

        if self.num_output_module == 1:
            self.bili = PredictionBiLinear(vect_size, vect_size,
                                           config.relation_num)
        else:
            # self.bili_multi = PredictionBiLinearMulti(self.num_output_module, vect_size, vect_size, config.relation_num)
            bili_list = [
                PredictionBiLinear(vect_size, vect_size, config.relation_num)
                for _ in range(self.num_output_module)
            ]
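            # twin_init: start the first two prediction heads from identical weights.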
            if self.twin_init:
                bili_list[1].load_state_dict(bili_list[0].state_dict())
            self.bili_list = nn.ModuleList(bili_list)

    def fix_prediction_bias(self, bias):
        assert self.num_output_module == 1
        self.bili.fix_bias(bias)

    def add_prediction_bias(self, bias):
        assert self.num_output_module == 1
        self.bili.add_bias(bias)

    @conditional_profiler
    def forward(
        self,
        context_idxs,
        pos,
        context_ner,
        context_char_idxs,
        context_lens,
        h_mapping,
        t_mapping,
        relation_mask,
        dis_h_2_t,
        dis_t_2_h,
        sent_idxs,
        sent_lengths,
        reverse_sent_idxs,
        context_masks,
        context_starts,
    ):

        bert_out = self.bert(context_idxs, attention_mask=context_masks)[0]
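        # bert_out: [batch, num_wordpieces, 768]. The comprehension below keeps,
        # per sequence, only the hidden states at positions flagged in
        # context_starts (the first word-piece of each word), giving
        # variable-length word-level sequences that are then re-padded.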

        context_output = [
            layer[starts.nonzero().squeeze(1)]
            for layer, starts in zip(bert_out, context_starts)
        ]
        context_output = pad_sequence(context_output,
                                      batch_first=True,
                                      padding_value=-1)

        context_output = torch.nn.functional.pad(
            context_output,
            (0, 0, 0, context_idxs.size(-1) - context_output.size(-2)))
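        # Pad the word axis back to the word-piece length of context_idxs so that
        # h_mapping / t_mapping (built over that length) still line up for matmul.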

        # context_output = bert_out

        context = self.linear(context_output)
        if self.use_entity_type:
            context = torch.cat([context, self.ner_emb(context_ner)], dim=-1)

        start_re_output = torch.matmul(h_mapping, context)
        end_re_output = torch.matmul(t_mapping, context)

        if self.use_cross_attention:
            tail, _ = self.attention(start_re_output, context, t_mapping)
            head, _ = self.attention(end_re_output, context, h_mapping)
        else:
            head = start_re_output
            tail = end_re_output

        if self.use_distance:
            s_rep = torch.cat([head, self.dis_embed(dis_h_2_t)], dim=-1)
            t_rep = torch.cat([tail, self.dis_embed(dis_t_2_h)], dim=-1)
        else:
            s_rep = head
            t_rep = tail

        if self.num_output_module == 1:
            predict_re_ha_logit = self.bili(s_rep, t_rep)
            return predict_re_ha_logit
        # output_list = self.bili_multi(s_rep, t_rep)
        output_list = [bili(s_rep, t_rep) for bili in self.bili_list]

        return output_list
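For reference, here is the subword realignment from BERT_RE.forward isolated on dummy tensors, since the gather / pad_sequence / pad chain is easy to misread inline. All sizes are made up for illustration:

import torch
from torch.nn.utils.rnn import pad_sequence

bert_dim = 4
bert_out = torch.randn(2, 6, bert_dim)                      # word-piece states
context_starts = torch.tensor([[1, 0, 1, 1, 0, 0],          # 3 words in sequence 0
                               [1, 1, 0, 1, 1, 0]])         # 4 words in sequence 1

# Keep only the hidden state of each word's first word-piece.
word_states = [layer[starts.nonzero().squeeze(1)]
               for layer, starts in zip(bert_out, context_starts)]

# Re-pad to [batch, max_words, dim], then pad the word axis back to the
# word-piece length so mappings built over the full length still line up.
word_states = pad_sequence(word_states, batch_first=True, padding_value=-1)
word_states = torch.nn.functional.pad(
    word_states, (0, 0, 0, bert_out.size(1) - word_states.size(-2)))
print(word_states.shape)                                    # torch.Size([2, 6, 4])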