class RelationRanking(nn.Module):
    """Rank candidate relations against a question.

    Four partial scores are computed for every (question, relation) pair:
    two attention-based scores from ``cal_score`` (one per relation
    embedding table) and two interaction-matrix CNN scores from
    ``matchPyramid``.  ``fc2`` maps the four values to a single score.

    NOTE(review): this file defines ``RelationRanking`` a second time further
    down; at import time the later definition rebinds the name.
    """

    def __init__(self, word_vocab, rel_vocab, config):
        """Build all sub-modules.

        :param word_vocab: vocabulary handed to the word ``Embeddings``
        :param rel_vocab: pair ``(rel1_vocab, rel2_vocab)``, one vocabulary
            per relation embedding table
        :param config: object exposing the hyper-parameters read below
            (``d_word_embed``, ``d_rel_embed``, ``d_hidden``, ``n_layers``,
            ``dropout_prob``, ``birnn``, ``rnn_type``, ``channel_size``,
            ``conv_kernel_1``, ``conv_kernel_2``, ``seq_maxlen``,
            ``rel_maxlen``, ``cuda``)
        """
        super(RelationRanking, self).__init__()
        self.config = config
        rel1_vocab, rel2_vocab = rel_vocab

        self.word_embed = Embeddings(word_vec_size=config.d_word_embed, dicts=word_vocab)
        self.rel1_embed = Embeddings(word_vec_size=config.d_rel_embed, dicts=rel1_vocab)
        self.rel2_embed = Embeddings(word_vec_size=config.d_rel_embed, dicts=rel2_vocab)

        # Anything other than 'gru' falls back to an LSTM.
        rnn_cls = nn.GRU if config.rnn_type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(input_size=config.d_word_embed,
                           hidden_size=config.d_hidden,
                           num_layers=config.n_layers,
                           dropout=config.dropout_prob,
                           bidirectional=config.birnn,
                           batch_first=True)
        self.dropout = nn.Dropout(p=config.dropout_prob)

        # A bidirectional encoder concatenates both directions.
        seq_in_size = config.d_hidden * 2 if config.birnn else config.d_hidden

        self.question_attention = MLPWordSeqAttention(input_size=config.d_rel_embed,
                                                      seq_size=seq_in_size)
        self.bilinear = nn.Bilinear(seq_in_size, config.d_rel_embed, 1, bias=False)
        self.seq_out = nn.Sequential(
            self.dropout,
            nn.Linear(seq_in_size, config.d_rel_embed),
        )

        # Single input channel: the question/relation similarity matrix.
        self.conv = nn.Sequential(
            nn.Conv2d(1, config.channel_size,
                      (config.conv_kernel_1, config.conv_kernel_2),
                      stride=1,
                      padding=(config.conv_kernel_1 // 2, config.conv_kernel_2 // 2)),
            nn.ReLU(True),
        )

        # With padding k // 2 an even kernel makes the conv output one
        # element longer than its input; these are the sizes of the map that
        # ``matchPyramid`` gathers into.
        self.seq_maxlen = config.seq_maxlen + (config.conv_kernel_1 + 1) % 2
        self.rel_maxlen = config.rel_maxlen + (config.conv_kernel_2 + 1) % 2

        # The gathered map is always (self.seq_maxlen, self.rel_maxlen), so
        # the pooling windows and the Linear input sizes must use those
        # values.  Sizing them from ``config.*_maxlen`` (as before) breaks
        # the Linear layers for any even kernel size.
        self.pooling = nn.MaxPool2d((self.seq_maxlen, 1),
                                    stride=(self.seq_maxlen, 1), padding=0)
        self.pooling2 = nn.MaxPool2d((1, self.rel_maxlen),
                                     stride=(1, self.rel_maxlen), padding=0)

        self.fc = nn.Sequential(
            nn.Linear(self.rel_maxlen * config.channel_size, 20),
            nn.ReLU(),
            nn.Dropout(p=config.dropout_prob),
            nn.Linear(20, 1))
        self.fc1 = nn.Sequential(
            nn.Linear(self.seq_maxlen * config.channel_size, 20),
            nn.ReLU(),
            nn.Dropout(p=config.dropout_prob),
            nn.Linear(20, 1))
        # Combines the four partial scores concatenated in ``forward``.
        self.fc2 = nn.Sequential(nn.Linear(4, 1))

    def question_encoder(self, inputs):
        """Run the question RNN over embedded tokens.

        :param inputs: embedded question, batch-first
            (batch, sequence length, word embedding size)
        :return: per-token RNN outputs,
            (batch, sequence length, hidden size x num directions)
        """
        # Omitting the initial state lets the RNN start from zeros, which is
        # what the explicit zero tensors did before; GRU and LSTM both return
        # ``(outputs, state)`` so no branching is needed.
        outputs, _ = self.rnn(inputs)
        # ``contiguous`` is not in-place: its result has to be returned.
        return outputs.contiguous()

    def cal_score(self, outputs, seqs_len, rel_embed, pos=None):
        """Attention-based score between encoded questions and relations.

        :param outputs: encoded questions (batch, seq_len, enc_size)
        :param seqs_len: question lengths; currently unused because the
            attention module is called without lengths, kept for the
            existing call signature
        :param rel_embed: (batch, dim) for positives or
            (neg_size, batch, dim) for negatives
        :param pos: for positive relations, the number of negatives to
            broadcast the result over; falsy for negatives
        :return: scores of shape (neg_size, batch, 1)
        """
        batch_size = outputs.size(0)
        if pos:
            neg_size = pos
        else:
            # Negatives: repeat every question once per negative sample and
            # flatten (neg_size, batch) into a single batch axis.
            neg_size, batch_size, _ = rel_embed.size()
            seq_len, seq_emb_size = outputs.size()[1:]
            outputs = (outputs.unsqueeze(0)
                       .expand(neg_size, batch_size, seq_len, seq_emb_size)
                       .contiguous()
                       .view(neg_size * batch_size, seq_len, -1))
            rel_embed = rel_embed.view(neg_size * batch_size, -1)

        # The second return value (attention weights) is not needed here.
        seq_att, _ = self.question_attention.forward(rel_embed, outputs)
        seq_encode = self.seq_out(seq_att)
        # Dot product between projected question and relation embedding.
        score = torch.sum(seq_encode * rel_embed, 1, keepdim=True)

        if pos:
            # Positives are computed once and broadcast over the negatives.
            return score.unsqueeze(0).expand(neg_size, batch_size, 1)
        return score.view(neg_size, batch_size, 1)

    def matchPyramid(self, seq, rel, seq_len, rel_len):
        """Score a question/relation pair from their word-similarity matrix.

        :param seq: (batch, _seq_len, embed_size) embedded question words
        :param rel: (batch, _rel_len, embed_size) embedded relation words
        :param seq_len: (batch,) true question lengths
        :param rel_len: (batch,) true relation lengths
        :return: ``(out, out2)``, each (batch, 1); ``out`` comes from pooling
            over the question axis, ``out2`` from pooling over the relation
            axis
        """
        batch_size = seq.size(0)

        # Cosine similarity of every word pair.  ``normalize`` clamps the
        # norm from below, so all-zero rows give 0 instead of NaN from 0 / 0.
        seq_unit = nn.functional.normalize(seq, p=2, dim=2, eps=1e-8)
        rel_unit = nn.functional.normalize(rel, p=2, dim=2, eps=1e-8)
        # (batch, 1, seq_len, rel_len)
        cross = torch.bmm(seq_unit, rel_unit.transpose(1, 2)).unsqueeze(1)

        # (batch, channel_size, seq_len', rel_len')
        conv1 = self.conv(cross)
        channel_size = conv1.size(1)

        # Dynamic pooling: stretch the valid region of every sample onto a
        # fixed (self.seq_maxlen, self.rel_maxlen) grid.
        index1, index2 = self.dynamic_pooling_index(
            seq_len, rel_len, self.seq_maxlen, self.rel_maxlen)
        index1 = index1.unsqueeze(1).unsqueeze(-1).expand(
            batch_size, channel_size, self.seq_maxlen, self.rel_maxlen)
        index2 = index2.unsqueeze(1).unsqueeze(2).expand_as(index1)
        stretched = torch.gather(conv1, 2, index1)
        stretched = torch.gather(stretched, 3, index2)

        pooled_seq = self.pooling(stretched).view(batch_size, -1)
        out = self.fc(pooled_seq)
        pooled_rel = self.pooling2(stretched).view(batch_size, -1)
        out2 = self.fc1(pooled_rel)
        return out, out2

    def dynamic_pooling_index(self, len1, len2, max_len1, max_len2):
        """Build gather indices that stretch each sample to a fixed size.

        Target position ``i`` reads source position
        ``int(i / (max_len / length))``.

        :param len1: per-sample lengths along the first axis
        :param len2: per-sample lengths along the second axis
        :param max_len1: target size of the first axis
        :param max_len2: target size of the second axis
        :return: two LongTensors of shape (batch, max_len1) and
            (batch, max_len2), moved to CUDA when ``self.config.cuda`` is set
        """
        def stretch_index(length, max_len):
            # Treat a zero length as 1 so every target position maps to
            # source index 0 rather than dividing by zero.
            stride = 1.0 * max_len / max(float(length), 1.0)
            return [int(i / stride) for i in range(max_len)]

        batch_size = len(len1)
        index1 = torch.LongTensor(
            [stretch_index(len1[b], max_len1) for b in range(batch_size)])
        index2 = torch.LongTensor(
            [stretch_index(len2[b], max_len2) for b in range(batch_size)])
        if self.config.cuda:
            index1 = index1.cuda()
            index2 = index2.cuda()
        return index1, index2

    def forward(self, batch):
        """Score positive and negative relations for a batch of questions.

        :param batch: tuple ``(seqs, seq_len, pos_rel1, pos_rel2, neg_rel1,
            neg_rel2, pos_rel, pos_rel_len, neg_rel, neg_rel_len)``;
            ``neg_rel`` is (neg_size, batch, relation word length)
        :return: ``(pos_score, neg_score)``, each (neg_size, batch)
        """
        (seqs, seq_len, pos_rel1, pos_rel2, neg_rel1, neg_rel2,
         pos_rel, pos_rel_len, neg_rel, neg_rel_len) = batch

        # (batch, sequence length, word embedding size)
        inputs = self.word_embed.forward(seqs)
        outputs = self.question_encoder(inputs)

        # (batch, d_rel_embed)
        pos_rel1_embed = self.dropout(self.rel1_embed.word_lookup_table(pos_rel1))
        pos_rel2_embed = self.dropout(self.rel2_embed.word_lookup_table(pos_rel2))
        # (neg_size, batch, d_rel_embed)
        neg_rel1_embed = self.dropout(self.rel1_embed.word_lookup_table(neg_rel1))
        neg_rel2_embed = self.dropout(self.rel2_embed.word_lookup_table(neg_rel2))

        neg_size, batch, neg_len = neg_rel.size()

        # Attention scores, each (neg_size, batch, 1).
        pos_score1 = self.cal_score(outputs, seq_len, pos_rel1_embed, neg_size)
        pos_score2 = self.cal_score(outputs, seq_len, pos_rel2_embed, neg_size)
        neg_score1 = self.cal_score(outputs, seq_len, neg_rel1_embed)
        neg_score2 = self.cal_score(outputs, seq_len, neg_rel2_embed)

        # CNN scores for the positive relation, broadcast over negatives.
        pos_embed = self.word_embed.forward(pos_rel)
        pos_score3, pos_score4 = self.matchPyramid(inputs, pos_embed, seq_len, pos_rel_len)
        pos_score3 = pos_score3.unsqueeze(0).expand(neg_size, batch, pos_score3.size(1))
        pos_score4 = pos_score4.unsqueeze(0).expand(neg_size, batch, pos_score4.size(1))

        # CNN scores for negatives: flatten (neg_size, batch) into one axis
        # and repeat the question embeddings accordingly.
        neg_embed = self.word_embed.forward(neg_rel.view(-1, neg_len))
        seqs_embed = (inputs.unsqueeze(0)
                      .expand(neg_size, batch, inputs.size(1), inputs.size(2))
                      .contiguous()
                      .view(-1, inputs.size(1), inputs.size(2)))
        neg_rel_len = neg_rel_len.view(-1)
        seq_len = seq_len.unsqueeze(0).expand(neg_size, batch).contiguous().view(-1)
        neg_score3, neg_score4 = self.matchPyramid(seqs_embed, neg_embed, seq_len, neg_rel_len)
        neg_score3 = neg_score3.view(neg_size, batch, neg_score3.size(1))
        neg_score4 = neg_score4.view(neg_size, batch, neg_score4.size(1))

        pos_concat = torch.cat((pos_score1, pos_score2, pos_score3, pos_score4), 2)
        neg_concat = torch.cat((neg_score1, neg_score2, neg_score3, neg_score4), 2)
        pos_score = self.fc2(pos_concat).squeeze(-1)
        neg_score = self.fc2(neg_concat).squeeze(-1)
        return pos_score, neg_score
class RelationRanking(nn.Module):
    """Rank candidate relations against a question (summed-score variant).

    Three partial scores are computed for every (question, relation) pair:
    two attention-based scores from ``cal_score`` (one per relation
    embedding table) and one interaction-matrix CNN score from
    ``matchPyramid``.  ``forward`` returns their sum.

    NOTE(review): ``RelationRanking`` is also defined earlier in this file;
    this later definition is the one bound to the name after import.
    """

    def __init__(self, word_vocab, rel_vocab, config):
        """Build all sub-modules.

        :param word_vocab: vocabulary handed to the word ``Embeddings``
        :param rel_vocab: pair ``(rel1_vocab, rel2_vocab)``, one vocabulary
            per relation embedding table
        :param config: object exposing the hyper-parameters read below
            (``d_word_embed``, ``d_rel_embed``, ``d_hidden``, ``n_layers``,
            ``dropout_prob``, ``birnn``, ``rnn_type``, ``channel_size``,
            ``conv_kernel_1``, ``conv_kernel_2``, ``pool_kernel_1``,
            ``pool_kernel_2``, ``seq_maxlen``, ``rel_maxlen``, ``cuda``)
        """
        super(RelationRanking, self).__init__()
        self.config = config
        rel1_vocab, rel2_vocab = rel_vocab

        self.word_embed = Embeddings(word_vec_size=config.d_word_embed, dicts=word_vocab)
        self.rel1_embed = Embeddings(word_vec_size=config.d_rel_embed, dicts=rel1_vocab)
        self.rel2_embed = Embeddings(word_vec_size=config.d_rel_embed, dicts=rel2_vocab)
        # TODO: revisit relation-embedding initialisation
        # (e.g. lookup_table.weight.data.normal_(0, 0.1)).

        # Anything other than 'gru' falls back to an LSTM.
        rnn_cls = nn.GRU if config.rnn_type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(input_size=config.d_word_embed,
                           hidden_size=config.d_hidden,
                           num_layers=config.n_layers,
                           dropout=config.dropout_prob,
                           bidirectional=config.birnn,
                           batch_first=True)
        self.dropout = nn.Dropout(p=config.dropout_prob)

        # A bidirectional encoder concatenates both directions.
        seq_in_size = config.d_hidden * 2 if config.birnn else config.d_hidden

        self.question_attention = MLPWordSeqAttention(input_size=config.d_rel_embed,
                                                      seq_size=seq_in_size)
        self.bilinear = nn.Bilinear(seq_in_size, config.d_rel_embed, 1, bias=False)
        self.seq_out = nn.Sequential(
            self.dropout,
            nn.Linear(seq_in_size, config.d_rel_embed),
        )

        # Single input channel: the question/relation similarity matrix.
        self.conv = nn.Sequential(
            nn.Conv2d(1, config.channel_size,
                      (config.conv_kernel_1, config.conv_kernel_2),
                      stride=1,
                      padding=(config.conv_kernel_1 // 2, config.conv_kernel_2 // 2)),
            nn.ReLU(True),
        )

        # With padding k // 2 an even kernel makes the conv output one
        # element longer than its input; these are the sizes of the map that
        # ``matchPyramid`` gathers into.
        self.seq_maxlen = config.seq_maxlen + (config.conv_kernel_1 + 1) % 2
        self.rel_maxlen = config.rel_maxlen + (config.conv_kernel_2 + 1) % 2

        # Size of the pooled map, used for the first Linear layer.
        p_size1 = self.seq_maxlen // config.pool_kernel_1
        p_size2 = self.rel_maxlen // config.pool_kernel_2
        self.pooling = nn.MaxPool2d((config.pool_kernel_1, config.pool_kernel_2),
                                    stride=(config.pool_kernel_1, config.pool_kernel_2),
                                    padding=0)

        self.fc = nn.Sequential(
            nn.Linear(p_size1 * p_size2 * config.channel_size, 20),
            nn.ReLU(),
            nn.Dropout(p=config.dropout_prob),
            nn.Linear(20, 1),
            nn.Sigmoid())

    def question_encoder(self, inputs):
        """Run the question RNN over embedded tokens.

        :param inputs: embedded question, batch-first
            (batch, sequence length, word embedding size)
        :return: per-token RNN outputs,
            (batch, sequence length, hidden size x num directions)
        """
        # Omitting the initial state lets the RNN start from zeros, which is
        # what the explicit zero tensors did before; GRU and LSTM both return
        # ``(outputs, state)`` so no branching is needed.
        outputs, _ = self.rnn(inputs)
        # ``contiguous`` is not in-place: its result has to be returned.
        return outputs.contiguous()

    def cal_score(self, outputs, seqs_len, rel_embed, pos=None):
        """Attention-based score between encoded questions and relations.

        :param outputs: encoded questions (batch, seq_len, enc_size)
        :param seqs_len: question lengths; currently unused because the
            attention module is called without lengths, kept for the
            existing call signature
        :param rel_embed: (batch, dim) for positives or
            (neg_size, batch, dim) for negatives
        :param pos: for positive relations, the number of negatives to
            broadcast the result over; falsy for negatives
        :return: scores of shape (neg_size, batch)
        """
        batch_size = outputs.size(0)
        if pos:
            neg_size = pos
        else:
            # Negatives need expanding: repeat every question once per
            # negative sample and flatten (neg_size, batch) into one axis.
            neg_size, batch_size, _ = rel_embed.size()
            seq_len, seq_emb_size = outputs.size()[1:]
            outputs = (outputs.unsqueeze(0)
                       .expand(neg_size, batch_size, seq_len, seq_emb_size)
                       .contiguous()
                       .view(neg_size * batch_size, seq_len, -1))
            rel_embed = rel_embed.view(neg_size * batch_size, -1)

        # The second return value (attention weights) is not needed here.
        seq_att, _ = self.question_attention.forward(rel_embed, outputs)
        seq_encode = self.seq_out(seq_att)
        # Dot product between projected question and relation embedding.
        score = torch.sum(seq_encode * rel_embed, 1, keepdim=True)

        if pos:
            # Positives are computed once; broadcast the result over the
            # negatives.
            return score.squeeze(1).unsqueeze(0).expand(neg_size, batch_size)
        return score.view(neg_size, batch_size)

    def matchPyramid(self, seq, rel, seq_len, rel_len):
        """Score a question/relation pair from their word-similarity matrix.

        :param seq: (batch, _seq_len, embed_size) embedded question words
        :param rel: (batch, _rel_len, embed_size) embedded relation words
        :param seq_len: (batch,) true question lengths
        :param rel_len: (batch,) true relation lengths
        :return: score of shape (batch, 1), squashed by the final sigmoid
        """
        batch_size = seq.size(0)

        # Cosine similarity (rather than a raw inner product) of every word
        # pair.  ``normalize`` clamps the norm from below, so all-zero rows
        # give 0 instead of NaN from 0 / 0.
        seq_unit = nn.functional.normalize(seq, p=2, dim=2, eps=1e-8)
        rel_unit = nn.functional.normalize(rel, p=2, dim=2, eps=1e-8)
        # (batch, 1, seq_len, rel_len)
        cross = torch.bmm(seq_unit, rel_unit.transpose(1, 2)).unsqueeze(1)

        # (batch, channel_size, seq_len', rel_len')
        conv1 = self.conv(cross)
        channel_size = conv1.size(1)

        # Dynamic pooling: stretch the valid region of every sample onto a
        # fixed (self.seq_maxlen, self.rel_maxlen) grid.
        index1, index2 = self.dynamic_pooling_index(
            seq_len, rel_len, self.seq_maxlen, self.rel_maxlen)
        index1 = index1.unsqueeze(1).unsqueeze(-1).expand(
            batch_size, channel_size, self.seq_maxlen, self.rel_maxlen)
        index2 = index2.unsqueeze(1).unsqueeze(2).expand_as(index1)
        stretched = torch.gather(conv1, 2, index1)
        stretched = torch.gather(stretched, 3, index2)

        # (batch, channel_size * p_size1 * p_size2)
        pooled = self.pooling(stretched).view(batch_size, -1)
        return self.fc(pooled)

    def dynamic_pooling_index(self, len1, len2, max_len1, max_len2):
        """Build gather indices that stretch each sample to a fixed size.

        Target position ``i`` reads source position
        ``int(i / (max_len / length))``.

        :param len1: per-sample lengths along the first axis
        :param len2: per-sample lengths along the second axis
        :param max_len1: target size of the first axis
        :param max_len2: target size of the second axis
        :return: two LongTensors of shape (batch, max_len1) and
            (batch, max_len2), moved to CUDA when ``self.config.cuda`` is set
        """
        def stretch_index(length, max_len):
            # Treat a zero length as 1 so every target position maps to
            # source index 0 rather than dividing by zero.
            stride = 1.0 * max_len / max(float(length), 1.0)
            return [int(i / stride) for i in range(max_len)]

        batch_size = len(len1)
        index1 = torch.LongTensor(
            [stretch_index(len1[b], max_len1) for b in range(batch_size)])
        index2 = torch.LongTensor(
            [stretch_index(len2[b], max_len2) for b in range(batch_size)])
        if self.config.cuda:
            index1 = index1.cuda()
            index2 = index2.cuda()
        return index1, index2

    def forward(self, batch):
        """Score positive and negative relations for a batch of questions.

        :param batch: tuple ``(seqs, seq_len, pos_rel1, pos_rel2, neg_rel1,
            neg_rel2, pos_rel, pos_rel_len, neg_rel, neg_rel_len)``;
            ``neg_rel`` is (neg_size, batch, relation word length)
        :return: ``(pos_score, neg_score)``, each (neg_size, batch), the sum
            of the two attention scores and the CNN score
        """
        (seqs, seq_len, pos_rel1, pos_rel2, neg_rel1, neg_rel2,
         pos_rel, pos_rel_len, neg_rel, neg_rel_len) = batch

        # (batch, sequence length, word embedding size)
        inputs = self.word_embed.forward(seqs)
        outputs = self.question_encoder(inputs)

        # (batch, d_rel_embed)
        pos_rel1_embed = self.dropout(self.rel1_embed.word_lookup_table(pos_rel1))
        pos_rel2_embed = self.dropout(self.rel2_embed.word_lookup_table(pos_rel2))
        # (neg_size, batch, d_rel_embed)
        neg_rel1_embed = self.dropout(self.rel1_embed.word_lookup_table(neg_rel1))
        neg_rel2_embed = self.dropout(self.rel2_embed.word_lookup_table(neg_rel2))

        neg_size, batch, neg_len = neg_rel.size()

        # Attention scores, each (neg_size, batch).
        pos_score1 = self.cal_score(outputs, seq_len, pos_rel1_embed, neg_size)
        pos_score2 = self.cal_score(outputs, seq_len, pos_rel2_embed, neg_size)
        neg_score1 = self.cal_score(outputs, seq_len, neg_rel1_embed)
        neg_score2 = self.cal_score(outputs, seq_len, neg_rel2_embed)

        # CNN score for the positive relation, broadcast over negatives.
        pos_embed = self.word_embed.forward(pos_rel)
        pos_score3 = self.matchPyramid(inputs, pos_embed, seq_len, pos_rel_len)
        pos_score3 = pos_score3.squeeze(-1).unsqueeze(0).expand(neg_size, batch)

        # CNN score for negatives: flatten (neg_size, batch) into one axis
        # and repeat the question embeddings accordingly.
        neg_embed = self.word_embed.forward(neg_rel.view(-1, neg_len))
        seqs_embed = (inputs.unsqueeze(0)
                      .expand(neg_size, batch, inputs.size(1), inputs.size(2))
                      .contiguous()
                      .view(-1, inputs.size(1), inputs.size(2)))
        neg_rel_len = neg_rel_len.view(-1)
        seq_len = seq_len.unsqueeze(0).expand(neg_size, batch).contiguous().view(-1)
        neg_score3 = self.matchPyramid(seqs_embed, neg_embed, seq_len, neg_rel_len)
        neg_score3 = neg_score3.squeeze(-1).view(neg_size, batch)

        return (pos_score1 + pos_score2 + pos_score3,
                neg_score1 + neg_score2 + neg_score3)
class ZYK(nn.Module):
    """Relation ranker combining two CNN scores with one attention score.

    The CNN scores come from ``matchPyramid`` (word-level similarity matrix
    between question and relation words).  The attention score comes from
    ``cal_score`` applied to the RNN-encoded question and a relation vector,
    which is either the raw relation embedding (``config.rel_direct``) or
    the final state of ``relation_encoder``.  ``fc2`` maps the three values
    to a single score.
    """

    def __init__(self, word_vocab, rel_vocab, config):
        """Build all sub-modules.

        :param word_vocab: vocabulary handed to the word ``Embeddings``
        :param rel_vocab: vocabulary handed to the relation ``Embeddings``
        :param config: object exposing the hyper-parameters read below
            (``d_word_embed``, ``d_rel_embed``, ``d_hidden``, ``n_layers``,
            ``dropout_prob``, ``birnn``, ``rnn_type``, ``channel_size``,
            ``conv_kernel_1``, ``conv_kernel_2``, ``seq_maxlen``,
            ``rel_maxlen``; optionally ``rel_direct`` and ``cnn``)
        """
        super(ZYK, self).__init__()
        self.config = config
        self.word_embed = Embeddings(word_vec_size=config.d_word_embed, dicts=word_vocab)
        self.rel_embed = Embeddings(word_vec_size=config.d_rel_embed, dicts=rel_vocab)

        # Anything other than 'gru' falls back to an LSTM.
        rnn_cls = nn.GRU if config.rnn_type.lower() == 'gru' else nn.LSTM
        self.question_encoder_rnn = rnn_cls(input_size=config.d_word_embed,
                                            hidden_size=config.d_hidden,
                                            num_layers=config.n_layers,
                                            dropout=config.dropout_prob,
                                            bidirectional=config.birnn,
                                            batch_first=True)
        # Time-major (batch_first is left at its default), single layer; its
        # hidden size equals the relation embedding size so a relation
        # embedding can be used as the initial state.
        self.relation_encoder_rnn = rnn_cls(input_size=config.d_word_embed,
                                            hidden_size=config.d_rel_embed,
                                            num_layers=1)
        self.dropout = nn.Dropout(p=config.dropout_prob)

        # A bidirectional encoder concatenates both directions.
        seq_in_size = config.d_hidden * 2 if config.birnn else config.d_hidden

        self.question_attention = MLPWordSeqAttention(input_size=config.d_rel_embed,
                                                      seq_size=seq_in_size)
        self.bilinear = nn.Bilinear(seq_in_size, config.d_rel_embed, 1, bias=False)
        self.seq_out = nn.Sequential(
            self.dropout,
            nn.Linear(seq_in_size, config.d_rel_embed),
        )

        # Single input channel: the question/relation similarity matrix.
        self.conv = nn.Sequential(
            nn.Conv2d(1, config.channel_size,
                      (config.conv_kernel_1, config.conv_kernel_2),
                      stride=1,
                      padding=(config.conv_kernel_1 // 2, config.conv_kernel_2 // 2)),
            nn.ReLU(True),
        )

        # With padding k // 2 an even kernel makes the conv output one
        # element longer than its input; these are the sizes of the map that
        # ``matchPyramid`` gathers into.
        self.seq_maxlen = config.seq_maxlen + (config.conv_kernel_1 + 1) % 2
        self.rel_maxlen = config.rel_maxlen + (config.conv_kernel_2 + 1) % 2

        # The gathered map is always (self.seq_maxlen, self.rel_maxlen), so
        # the pooling windows and the Linear input sizes must use those
        # values.  Sizing them from ``config.*_maxlen`` (as before) breaks
        # the Linear layers for any even kernel size.
        self.pooling = nn.MaxPool2d((self.seq_maxlen, 1),
                                    stride=(self.seq_maxlen, 1), padding=0)
        self.pooling2 = nn.MaxPool2d((1, self.rel_maxlen),
                                     stride=(1, self.rel_maxlen), padding=0)

        self.fc = nn.Sequential(
            nn.Linear(self.rel_maxlen * config.channel_size, 20),
            nn.ReLU(),
            nn.Dropout(p=config.dropout_prob),
            nn.Linear(20, 1))
        self.fc1 = nn.Sequential(
            nn.Linear(self.seq_maxlen * config.channel_size, 20),
            nn.ReLU(),
            nn.Dropout(p=config.dropout_prob),
            nn.Linear(20, 1))
        # Combines the three partial scores concatenated in ``forward``.
        self.fc2 = nn.Sequential(nn.Linear(3, 1))

    def question_encoder(self, inputs, _):
        """Encode the question with the question RNN.

        :param inputs: (batch, max_seq_len, word_dim) embedded question
        :param _: ignored by the encoder and handed back unchanged as the
            second return value
        :return: ``(outputs, _)`` where ``outputs`` is
            (batch, sequence length, hidden size x num directions)
        """
        # Omitting the initial state lets the RNN start from zeros, which is
        # what the explicit zero tensors did before; GRU and LSTM both return
        # ``(outputs, state)`` so no branching is needed.
        outputs, _state = self.question_encoder_rnn(inputs)
        # ``contiguous`` is not in-place: its result has to be returned.
        return outputs.contiguous(), _

    def relation_encoder(self, inputs, inputs_length, h0):
        """Encode relation word sequences with the relation RNN.

        Sequences are sorted by decreasing length for packing and restored
        to their original batch order before returning.

        :param inputs: (batch, max_len, dim) embedded relation words
        :param inputs_length: (batch,) true length of every sequence
        :param h0: optional (batch, d_rel_embed) initial hidden state (also
            used as the initial cell state for an LSTM); zeros when ``None``
        :return: ``(outputs, hidden)`` with ``outputs`` of shape
            (batch, longest length in the batch, d_rel_embed) and ``hidden``
            of shape (batch, d_rel_embed), both in the original batch order
        """
        device, batch_size = inputs.device, inputs.size(0)

        lengths = inputs_length if torch.is_tensor(inputs_length) \
            else torch.tensor(inputs_length)
        # Lengths are kept on the CPU for packing.
        lengths = lengths.long().cpu()
        sorted_lengths, sort_index = torch.sort(lengths, descending=True)
        sort_index = sort_index.to(device)
        # Inverse permutation, used to restore the original order.
        _, restore_index = torch.sort(sort_index)

        # Reorder the *sequences* to match the sorted lengths and switch to
        # the time-major layout expected by ``relation_encoder_rnn``.  The
        # previous code packed the unsorted inputs with sorted lengths, so
        # sequences were paired with the wrong lengths.
        inputs_sorted = inputs.index_select(0, sort_index).transpose(0, 1).contiguous()
        packed = torch.nn.utils.rnn.pack_padded_sequence(inputs_sorted, sorted_lengths)

        if h0 is not None:
            # The initial state has to follow the same permutation.
            init = h0.index_select(0, sort_index).unsqueeze(0).contiguous()
        else:
            init = inputs.new_zeros(1, batch_size, self.config.d_rel_embed)

        if self.config.rnn_type.lower() == 'gru':
            outputs, ht = self.relation_encoder_rnn(packed, init)
        else:
            outputs, (ht, _) = self.relation_encoder_rnn(packed, (init, init))
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)

        # Undo the sort along the batch axis (axis 1 while time-major).
        outputs_resorted = outputs.index_select(1, restore_index).transpose(0, 1).contiguous()
        hidden_resorted = ht.index_select(1, restore_index)
        return outputs_resorted, hidden_resorted.squeeze(0)

    def cal_score(self, outputs, seqs_len, rel_embed, pos=None):
        """Attention-based score between encoded questions and relations.

        :param outputs: encoded questions (batch, seq_len, enc_size)
        :param seqs_len: question lengths; currently unused because the
            attention module is called without lengths, kept for the
            existing call signature
        :param rel_embed: (batch, dim) for positives or
            (neg_size, batch, dim) for negatives
        :param pos: for positive relations, the number of negatives to
            broadcast the result over; falsy for negatives
        :return: scores of shape (neg_size, batch, 1)
        """
        batch_size = outputs.size(0)
        if pos:
            neg_size = pos
        else:
            # Negatives: repeat every question once per negative sample and
            # flatten (neg_size, batch) into a single batch axis.
            neg_size, batch_size, _ = rel_embed.size()
            seq_len, seq_emb_size = outputs.size()[1:]
            outputs = (outputs.unsqueeze(0)
                       .expand(neg_size, batch_size, seq_len, seq_emb_size)
                       .contiguous()
                       .view(neg_size * batch_size, seq_len, -1))
            rel_embed = rel_embed.view(neg_size * batch_size, -1)

        # The second return value (attention weights) is not needed here.
        seq_att, _ = self.question_attention.forward(rel_embed, outputs)
        seq_encode = self.seq_out(seq_att)
        # Dot product between projected question and relation vector.
        score = torch.sum(seq_encode * rel_embed, 1, keepdim=True)

        if pos:
            # Positives are computed once and broadcast over the negatives.
            return score.unsqueeze(0).expand(neg_size, batch_size, 1)
        return score.view(neg_size, batch_size, 1)

    def matchPyramid(self, seq, rel, seq_len, rel_len):
        """Score a question/relation pair from their word-similarity matrix.

        :param seq: (batch, _seq_len, embed_size) embedded question words
        :param rel: (batch, _rel_len, embed_size) embedded relation words
        :param seq_len: (batch,) true question lengths
        :param rel_len: (batch,) true relation lengths
        :return: ``(out, out2)``, each (batch, 1); ``out`` comes from pooling
            over the question axis, ``out2`` from pooling over the relation
            axis
        """
        batch_size = seq.size(0)
        device = seq.device

        # Cosine similarity of every word pair.  ``normalize`` clamps the
        # norm from below, so all-zero rows give 0 instead of NaN from 0 / 0.
        seq_unit = nn.functional.normalize(seq, p=2, dim=2, eps=1e-8)
        rel_unit = nn.functional.normalize(rel, p=2, dim=2, eps=1e-8)
        # (batch, 1, seq_len, rel_len)
        cross = torch.bmm(seq_unit, rel_unit.transpose(1, 2)).unsqueeze(1)

        # (batch, channel_size, seq_len', rel_len')
        conv1 = self.conv(cross)
        channel_size = conv1.size(1)

        # Dynamic pooling: stretch the valid region of every sample onto a
        # fixed (self.seq_maxlen, self.rel_maxlen) grid.
        index1, index2 = self.dynamic_pooling_index(
            seq_len, rel_len, self.seq_maxlen, self.rel_maxlen)
        index1 = index1.to(device).unsqueeze(1).unsqueeze(-1).expand(
            batch_size, channel_size, self.seq_maxlen, self.rel_maxlen)
        index2 = index2.to(device).unsqueeze(1).unsqueeze(2).expand_as(index1)
        stretched = torch.gather(conv1, 2, index1)
        stretched = torch.gather(stretched, 3, index2)

        pooled_seq = self.pooling(stretched).view(batch_size, -1)
        out = self.fc(pooled_seq)
        pooled_rel = self.pooling2(stretched).view(batch_size, -1)
        out2 = self.fc1(pooled_rel)
        return out, out2

    def dynamic_pooling_index(self, len1, len2, max_len1, max_len2):
        """Build gather indices that stretch each sample to a fixed size.

        Target position ``i`` reads source position
        ``int(i / (max_len / length))``.

        :param len1: per-sample lengths along the first axis
        :param len2: per-sample lengths along the second axis
        :param max_len1: target size of the first axis
        :param max_len2: target size of the second axis
        :return: two CPU LongTensors of shape (batch, max_len1) and
            (batch, max_len2); the caller moves them to its device
        """
        def stretch_index(length, max_len):
            # ``float`` accepts plain numbers as well as 0-d tensors.  Treat
            # a zero length as 1 so every target position maps to source
            # index 0 rather than dividing by zero.
            stride = 1.0 * max_len / max(float(length), 1.0)
            return [int(i / stride) for i in range(max_len)]

        batch_size = len(len1)
        index1 = torch.LongTensor(
            [stretch_index(len1[b], max_len1) for b in range(batch_size)])
        index2 = torch.LongTensor(
            [stretch_index(len2[b], max_len2) for b in range(batch_size)])
        return index1, index2

    def forward(self, batch):
        """Score positive and negative relations for a batch of questions.

        :param batch: tuple of
            seqs: (batch_size, max_seq_len), seq_len: (batch_size)
            pos_rel: (batch_size)
            pos_rel_word: (batch_size, max_rel_len),
            pos_rel_word_len: (batch_size)
            neg_rel: (neg_size, batch_size)
            neg_rel_word: (neg_size, batch_size, max_rel_len),
            neg_rel_word_len: (neg_size, batch_size)
        :return: ``(pos_score, neg_score)``, each (neg_size, batch_size)
        """
        (seqs, seq_len, pos_rel, pos_rel_word, pos_rel_word_len,
         neg_rel, neg_rel_word, neg_rel_word_len) = batch

        # (batch_size, max_seq_len, d_word_embed)
        seqs = self.word_embed.forward(seqs)
        neg_size, batch, max_rel_len = neg_rel_word.size()
        # Relation words are embedded with the *word* embedding table.
        pos_rel_word = self.word_embed.forward(pos_rel_word)
        # Negatives are flattened to (neg_size * batch_size, max_rel_len, dim).
        neg_rel_word = self.word_embed.forward(neg_rel_word.view(-1, max_rel_len))
        neg_rel_word_len = neg_rel_word_len.view(-1)

        # CNN scores for the positive relation, broadcast over negatives.
        pos_cnn_score1, pos_cnn_score2 = self.matchPyramid(
            seqs, pos_rel_word, seq_len, pos_rel_word_len)
        pos_cnn_score1 = pos_cnn_score1.unsqueeze(0).expand(neg_size, batch, pos_cnn_score1.size(1))
        pos_cnn_score2 = pos_cnn_score2.unsqueeze(0).expand(neg_size, batch, pos_cnn_score2.size(1))

        # CNN scores for negatives: repeat the questions once per negative.
        seqs_extend = (seqs.unsqueeze(0)
                       .expand(neg_size, batch, seqs.size(1), seqs.size(2))
                       .contiguous()
                       .view(-1, seqs.size(1), seqs.size(2)))
        seq_len_extend = seq_len.unsqueeze(0).expand(neg_size, batch).contiguous().view(-1)
        neg_cnn_score1, neg_cnn_score2 = self.matchPyramid(
            seqs_extend, neg_rel_word, seq_len_extend, neg_rel_word_len)
        neg_cnn_score1 = neg_cnn_score1.view(neg_size, batch, neg_cnn_score1.size(1))
        neg_cnn_score2 = neg_cnn_score2.view(neg_size, batch, neg_cnn_score2.size(1))

        seq_encoded, _ = self.question_encoder(seqs, seq_len)

        pos_rel = self.rel_embed.word_lookup_table(pos_rel)
        neg_rel = self.rel_embed.word_lookup_table(neg_rel.view(-1))
        if hasattr(self.config, "rel_direct") and self.config.rel_direct == True:
            # Use the relation embeddings directly.
            pos_rel_encoded = pos_rel
            neg_rel_encoded = neg_rel
        else:
            # Encode the relation words, seeded with the relation embedding.
            _, pos_rel_encoded = self.relation_encoder(pos_rel_word, pos_rel_word_len, pos_rel)
            _, neg_rel_encoded = self.relation_encoder(neg_rel_word, neg_rel_word_len, neg_rel)

        pos_rnn_score = self.cal_score(seq_encoded, seq_len, pos_rel_encoded, neg_size)
        # Let ``view`` infer the relation vector size.  It was previously
        # taken from the word-embedding tensor, which only works when
        # d_word_embed == d_rel_embed.
        neg_rnn_score = self.cal_score(
            seq_encoded, seq_len, neg_rel_encoded.view(neg_size, batch, -1))

        if hasattr(self.config, "cnn") and self.config.cnn == False:
            # Build fresh zero tensors instead of writing in place into an
            # expanded view, whose rows share memory.
            # NOTE(review): only the first CNN score is disabled here; the
            # second one still contributes - confirm this is intended.
            pos_cnn_score1 = torch.zeros_like(pos_cnn_score1)
            neg_cnn_score1 = torch.zeros_like(neg_cnn_score1)

        pos_concat = torch.cat((pos_cnn_score1, pos_cnn_score2, pos_rnn_score), 2)
        neg_concat = torch.cat((neg_cnn_score1, neg_cnn_score2, neg_rnn_score), 2)
        pos_score = self.fc2(pos_concat).squeeze(-1)
        neg_score = self.fc2(neg_concat).squeeze(-1)
        return pos_score, neg_score
class ZYK2(nn.Module):
    """Relation ranker combining one similarity-matrix score with one
    attention score.

    Unlike ``ZYK`` there is no convolution: ``matchPyramid`` pools the raw
    cosine-similarity matrix between question and relation words.  The
    attention score comes from ``cal_score`` applied to the RNN-encoded
    question and the relation embedding.  ``fc2`` maps the two values to a
    single score.
    """

    def __init__(self, word_vocab, rel_vocab, config):
        """Build all sub-modules.

        :param word_vocab: vocabulary handed to the word ``Embeddings``
        :param rel_vocab: vocabulary handed to the relation ``Embeddings``
        :param config: object exposing the hyper-parameters read below
            (``d_word_embed``, ``d_rel_embed``, ``d_hidden``, ``n_layers``,
            ``dropout_prob``, ``birnn``, ``rnn_type``, ``conv_kernel_1``,
            ``conv_kernel_2``, ``seq_maxlen``, ``rel_maxlen``)
        """
        super(ZYK2, self).__init__()
        self.config = config
        self.word_embed = Embeddings(word_vec_size=config.d_word_embed, dicts=word_vocab)
        self.rel_embed = Embeddings(word_vec_size=config.d_rel_embed, dicts=rel_vocab)

        # Anything other than 'gru' falls back to an LSTM.
        rnn_cls = nn.GRU if config.rnn_type.lower() == 'gru' else nn.LSTM
        self.question_encoder_rnn = rnn_cls(input_size=config.d_word_embed,
                                            hidden_size=config.d_hidden,
                                            num_layers=config.n_layers,
                                            dropout=config.dropout_prob,
                                            bidirectional=config.birnn,
                                            batch_first=True)
        self.dropout = nn.Dropout(p=config.dropout_prob)

        # A bidirectional encoder concatenates both directions.
        seq_in_size = config.d_hidden * 2 if config.birnn else config.d_hidden

        self.question_attention = MLPWordSeqAttention(input_size=config.d_rel_embed,
                                                      seq_size=seq_in_size)
        self.bilinear = nn.Bilinear(seq_in_size, config.d_rel_embed, 1, bias=False)
        self.seq_out = nn.Sequential(
            self.dropout,
            nn.Linear(seq_in_size, config.d_rel_embed),
        )

        # NOTE(review): these sizes still include the "+1 for an even conv
        # kernel" adjustment although this class applies no convolution -
        # confirm whether the adjustment is wanted here.
        self.seq_maxlen = config.seq_maxlen + (config.conv_kernel_1 + 1) % 2
        self.rel_maxlen = config.rel_maxlen + (config.conv_kernel_2 + 1) % 2

        # The gathered map is always (self.seq_maxlen, self.rel_maxlen), so
        # the pooling windows and the Linear input sizes must use those
        # values rather than ``config.*_maxlen``; otherwise the Linear
        # layers do not match for even kernel sizes.
        self.pooling = nn.MaxPool2d((self.seq_maxlen, 1),
                                    stride=(self.seq_maxlen, 1), padding=0)
        self.pooling2 = nn.MaxPool2d((1, self.rel_maxlen),
                                     stride=(1, self.rel_maxlen), padding=0)

        # Single channel (no conv), so the pooled vectors have one value per
        # position.
        self.fc = nn.Sequential(
            nn.Linear(self.rel_maxlen, 20),
            nn.ReLU(),
            nn.Dropout(p=config.dropout_prob),
            nn.Linear(20, 1))
        self.fc1 = nn.Sequential(
            nn.Linear(self.seq_maxlen, 20),
            nn.ReLU(),
            nn.Dropout(p=config.dropout_prob),
            nn.Linear(20, 1))
        # Combines the two partial scores concatenated in ``forward``.
        self.fc2 = nn.Sequential(nn.Linear(2, 1))

    def question_encoder(self, inputs, _):
        """Encode the question with the question RNN.

        :param inputs: (batch, max_seq_len, word_dim) embedded question
        :param _: ignored by the encoder and handed back unchanged as the
            second return value
        :return: ``(outputs, _)`` where ``outputs`` is
            (batch, sequence length, hidden size x num directions)
        """
        # Omitting the initial state lets the RNN start from zeros, which is
        # what the explicit zero tensors did before; GRU and LSTM both return
        # ``(outputs, state)`` so no branching is needed.
        outputs, _state = self.question_encoder_rnn(inputs)
        # ``contiguous`` is not in-place: its result has to be returned.
        return outputs.contiguous(), _

    def cal_score(self, outputs, seqs_len, rel_embed, pos=None):
        """Attention-based score between encoded questions and relations.

        :param outputs: encoded questions (batch, seq_len, enc_size)
        :param seqs_len: question lengths; currently unused because the
            attention module is called without lengths, kept for the
            existing call signature
        :param rel_embed: (batch, dim) for positives or
            (neg_size, batch, dim) for negatives
        :param pos: for positive relations, the number of negatives to
            broadcast the result over; falsy for negatives
        :return: scores of shape (neg_size, batch, 1)
        """
        batch_size = outputs.size(0)
        if pos:
            neg_size = pos
        else:
            # Negatives: repeat every question once per negative sample and
            # flatten (neg_size, batch) into a single batch axis.
            neg_size, batch_size, _ = rel_embed.size()
            seq_len, seq_emb_size = outputs.size()[1:]
            outputs = (outputs.unsqueeze(0)
                       .expand(neg_size, batch_size, seq_len, seq_emb_size)
                       .contiguous()
                       .view(neg_size * batch_size, seq_len, -1))
            rel_embed = rel_embed.view(neg_size * batch_size, -1)

        # The second return value (attention weights) is not needed here.
        seq_att, _ = self.question_attention.forward(rel_embed, outputs)
        seq_encode = self.seq_out(seq_att)
        # Dot product between projected question and relation embedding.
        score = torch.sum(seq_encode * rel_embed, 1, keepdim=True)

        if pos:
            # Positives are computed once and broadcast over the negatives.
            return score.unsqueeze(0).expand(neg_size, batch_size, 1)
        return score.view(neg_size, batch_size, 1)

    def matchPyramid(self, seq, rel, seq_len, rel_len):
        """Score a question/relation pair from their word-similarity matrix.

        :param seq: (batch, _seq_len, embed_size) embedded question words
        :param rel: (batch, _rel_len, embed_size) embedded relation words
        :param seq_len: (batch,) true question lengths
        :param rel_len: (batch,) true relation lengths
        :return: ``(out, out2)``, each (batch, 1); ``out`` comes from pooling
            over the question axis, ``out2`` from pooling over the relation
            axis
        """
        batch_size = seq.size(0)
        device = seq.device

        # Cosine similarity of every word pair.  ``normalize`` clamps the
        # norm from below, so all-zero rows give 0 instead of NaN from 0 / 0.
        seq_unit = nn.functional.normalize(seq, p=2, dim=2, eps=1e-8)
        rel_unit = nn.functional.normalize(rel, p=2, dim=2, eps=1e-8)
        # (batch, 1, seq_len, rel_len); used directly, there is no conv.
        cross = torch.bmm(seq_unit, rel_unit.transpose(1, 2)).unsqueeze(1)
        channel_size = cross.size(1)

        # Dynamic pooling: stretch the valid region of every sample onto a
        # fixed (self.seq_maxlen, self.rel_maxlen) grid.
        index1, index2 = self.dynamic_pooling_index(
            seq_len, rel_len, self.seq_maxlen, self.rel_maxlen)
        index1 = index1.to(device).unsqueeze(1).unsqueeze(-1).expand(
            batch_size, channel_size, self.seq_maxlen, self.rel_maxlen)
        index2 = index2.to(device).unsqueeze(1).unsqueeze(2).expand_as(index1)
        stretched = torch.gather(cross, 2, index1)
        stretched = torch.gather(stretched, 3, index2)

        pooled_seq = self.pooling(stretched).view(batch_size, -1)
        out = self.fc(pooled_seq)
        pooled_rel = self.pooling2(stretched).view(batch_size, -1)
        out2 = self.fc1(pooled_rel)
        return out, out2

    def dynamic_pooling_index(self, len1, len2, max_len1, max_len2):
        """Build gather indices that stretch each sample to a fixed size.

        Target position ``i`` reads source position
        ``int(i / (max_len / length))``.

        :param len1: per-sample lengths along the first axis
        :param len2: per-sample lengths along the second axis
        :param max_len1: target size of the first axis
        :param max_len2: target size of the second axis
        :return: two CPU LongTensors of shape (batch, max_len1) and
            (batch, max_len2); the caller moves them to its device
        """
        def stretch_index(length, max_len):
            # ``float`` accepts plain numbers as well as 0-d tensors.  Treat
            # a zero length as 1 so every target position maps to source
            # index 0 rather than dividing by zero.
            stride = 1.0 * max_len / max(float(length), 1.0)
            return [int(i / stride) for i in range(max_len)]

        batch_size = len(len1)
        index1 = torch.LongTensor(
            [stretch_index(len1[b], max_len1) for b in range(batch_size)])
        index2 = torch.LongTensor(
            [stretch_index(len2[b], max_len2) for b in range(batch_size)])
        return index1, index2

    def forward(self, batch):
        """Score positive and negative relations for a batch of questions.

        :param batch: tuple of
            seqs: (batch_size, max_seq_len), seq_len: (batch_size)
            pos_rel: (batch_size)
            pos_rel_word: (batch_size, max_rel_len),
            pos_rel_word_len: (batch_size)
            neg_rel: (neg_size, batch_size)
            neg_rel_word: (neg_size, batch_size, max_rel_len),
            neg_rel_word_len: (neg_size, batch_size)
        :return: ``(pos_score, neg_score)``, each (neg_size, batch_size)
        """
        (seqs, seq_len, pos_rel, pos_rel_word, pos_rel_word_len,
         neg_rel, neg_rel_word, neg_rel_word_len) = batch

        # (batch_size, max_seq_len, d_word_embed)
        seqs = self.word_embed.forward(seqs)
        neg_size, batch, max_rel_len = neg_rel_word.size()
        # Relation words are embedded with the *word* embedding table.
        pos_rel_word = self.word_embed.forward(pos_rel_word)
        # Negatives are flattened to (neg_size * batch_size, max_rel_len, dim).
        neg_rel_word = self.word_embed.forward(neg_rel_word.view(-1, max_rel_len))
        neg_rel_word_len = neg_rel_word_len.view(-1)

        # Only the first similarity score is used by this model; the second
        # value returned by ``matchPyramid`` is ignored.
        pos_cnn_score1, _ = self.matchPyramid(
            seqs, pos_rel_word, seq_len, pos_rel_word_len)
        pos_cnn_score1 = pos_cnn_score1.unsqueeze(0).expand(neg_size, batch, pos_cnn_score1.size(1))

        # Negatives: repeat the questions once per negative sample.
        seqs_extend = (seqs.unsqueeze(0)
                       .expand(neg_size, batch, seqs.size(1), seqs.size(2))
                       .contiguous()
                       .view(-1, seqs.size(1), seqs.size(2)))
        seq_len_extend = seq_len.unsqueeze(0).expand(neg_size, batch).contiguous().view(-1)
        neg_cnn_score1, _ = self.matchPyramid(
            seqs_extend, neg_rel_word, seq_len_extend, neg_rel_word_len)
        neg_cnn_score1 = neg_cnn_score1.view(neg_size, batch, neg_cnn_score1.size(1))

        seq_encoded, _ = self.question_encoder(seqs, seq_len)

        pos_rel = self.rel_embed.word_lookup_table(pos_rel)
        _, d_rel_embed = pos_rel.size()
        neg_rel = self.rel_embed.word_lookup_table(neg_rel.view(-1))

        pos_rnn_score = self.cal_score(seq_encoded, seq_len, pos_rel, neg_size)
        neg_rnn_score = self.cal_score(
            seq_encoded, seq_len, neg_rel.view(neg_size, batch, d_rel_embed))

        pos_concat = torch.cat((pos_cnn_score1, pos_rnn_score), 2)
        neg_concat = torch.cat((neg_cnn_score1, neg_rnn_score), 2)
        pos_score = self.fc2(pos_concat).squeeze(-1)
        neg_score = self.fc2(neg_concat).squeeze(-1)
        return pos_score, neg_score


# Disabled one-off script, kept for reference: it fills a
# (vocabulary size, 768) matrix with BERT outputs for single-token words.
#
# import torch.nn as nn
# import sys
# import torch
# from pytorch_pretrained_bert import BertTokenizer, BertForPreTraining
# sys.path.append("../tools/")
# word_vocab = torch.load("../../data/vocab/vocab.seq_rel.pt")
# model = BertForPreTraining.from_pretrained('bert-base-uncased')
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# model.eval()
#
#
# device = "cuda:2"
# model.to(device)
# vector_size, words = 768, []
# pretrained = torch.zeros(len(word_vocab), vector_size)
# for index, word in word_vocab.index2word.items():
#     if word in tokenizer.vocab:
#         tokens_tensor = torch.tensor([tokenizer.convert_tokens_to_ids([word])]).to(device)
#         encoded_layers, _ = model(tokens_tensor)
#         pretrained[index] = encoded_layers[0][0].cpu()
#     else:
#         nn.init.uniform(pretrained[index], 0, 0)
class RelationRanking(nn.Module):
    """Rank two relation parts against a question.

    The question is embedded and encoded by a GRU/LSTM; each relation
    embedding attends over the encoder outputs, and the attended vector is
    compared with the relation embedding to give a similarity score.
    """

    def __init__(self, word_vocab, rel_vocab, config):
        """
        :param word_vocab: vocabulary handed to the word ``Embeddings``
        :param rel_vocab: pair ``(rel1_vocab, rel2_vocab)``
        :param config: object providing ``d_word_embed``, ``d_rel_embed``,
            ``rnn_type``, ``d_hidden``, ``n_layers``, ``dropout_prob`` and
            ``birnn``
        """
        super(RelationRanking, self).__init__()
        self.config = config
        rel1_vocab, rel2_vocab = rel_vocab
        self.word_embed = Embeddings(word_vec_size=config.d_word_embed, dicts=word_vocab)
        self.rel1_embed = Embeddings(word_vec_size=config.d_rel_embed, dicts=rel1_vocab)
        self.rel2_embed = Embeddings(word_vec_size=config.d_rel_embed, dicts=rel2_vocab)
        # print(self.rel_embed.word_lookup_table.weight.data)
        # TODO: the rel_embed initialisation still needs to be revised, e.g.
        # rel_embed.lookup_table.weight.data.normal_(0, 0.1)

        # Both recurrent cell types take exactly the same constructor arguments.
        rnn_cls = nn.GRU if self.config.rnn_type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(input_size=config.d_word_embed,
                           hidden_size=config.d_hidden,
                           num_layers=config.n_layers,
                           dropout=config.dropout_prob,
                           bidirectional=config.birnn,
                           batch_first=True)

        self.dropout = nn.Dropout(p=config.dropout_prob)

        # A bidirectional encoder concatenates both directions.
        seq_in_size = config.d_hidden
        if self.config.birnn:
            seq_in_size *= 2

        self.question_attention = MLPWordSeqAttention(
            input_size=config.d_rel_embed, seq_size=seq_in_size)

        # Not used by cal_score at present; kept so the module's parameters
        # (and therefore saved checkpoints) stay the same.
        self.bilinear = nn.Bilinear(seq_in_size, config.d_rel_embed, 1, bias=False)

        self.seq_out = nn.Sequential(
            # nn.BatchNorm1d(seq_in_size),
            self.dropout,
            nn.Linear(seq_in_size, config.d_rel_embed))

    def question_encoder(self, inputs):
        '''
        Encode the embedded question with the recurrent network.

        :param inputs: (batch, sequence length, d_word_embed)
        :return: (batch size, sequence length, hidden size X num directions)
        '''
        # nn.GRU / nn.LSTM use an all-zero initial state when none is passed.
        # That equals the hand-built zero state of the earlier code without
        # depending on `config.n_cells` matching layers x directions.
        outputs, _ = self.rnn(inputs)
        # `contiguous()` is not in-place: return its result, otherwise the
        # call has no effect.
        return outputs.contiguous()

    def cal_score(self, outputs, seqs_len, rel_embed, pos=None):
        '''
        Similarity between attended question encodings and relation embeddings.

        :param outputs: (batch, seq_len, hidden size X num directions)
        :param seqs_len: (batch,) - accepted for interface compatibility, not
            used by the current attention call
        :param rel_embed: (batch, dim2) when ``pos`` is given, otherwise
            (neg_size, batch, dim2)
        :param pos: number of negatives; when given, ``rel_embed`` holds the
            positive relations and the score is broadcast ``pos`` times
        :return: (neg_size, batch)
        '''
        batch_size = outputs.size(0)
        # `is not None` rather than truthiness, so that pos=0 still selects
        # the positive branch instead of failing on the 3-way unpack below.
        is_positive = pos is not None
        if is_positive:
            neg_size = pos
        else:
            # Negatives have to be expanded: repeat every question once per
            # negative sample and flatten (neg_size, batch) into one dimension.
            neg_size, batch_size, _ = rel_embed.size()
            seq_len, seq_emb_size = outputs.size()[1:]
            outputs = outputs.unsqueeze(0).expand(
                neg_size, batch_size, seq_len, seq_emb_size
            ).contiguous().view(neg_size * batch_size, seq_len, -1)
            rel_embed = rel_embed.view(neg_size * batch_size, -1)

        # `weight` - (batch, length); only the attended vector is used here
        seq_att, weight = self.question_attention.forward(rel_embed, outputs)
        # `seq_encode` - (batch, d_rel_embed) after dropout + linear projection
        seq_encode = self.seq_out(seq_att)

        # `score` - (batch, 1) or (neg_size * batch, 1)
        # Product of an inverse-Euclidean term and a sigmoid of the shifted
        # dot product (this looks like the GESD similarity with gamma = c = 1).
        # The sigmoid used to be written `(1 / 1 + torch.exp(-(dot + 1)))`,
        # which evaluates to `1 + exp(...)` because of operator precedence.
        dot = torch.sum(seq_encode * rel_embed, 1, keepdim=True)
        dis = seq_encode - rel_embed
        euclidean = torch.sqrt(torch.sum(dis * dis, 1, keepdim=True))
        score = (1 / (1 + euclidean)) * torch.sigmoid(dot + 1)

        if is_positive:
            # Positives: broadcast the result across the negative samples.
            score = score.squeeze(1).unsqueeze(0).expand(neg_size, batch_size)
        else:
            score = score.view(neg_size, batch_size)
        return score

    def forward(self, batch):
        '''
        :param batch: tuple ``(seqs, seq_len, pos_rel1, pos_rel2, neg_rel1,
            neg_rel2)``; ``seqs`` is (batch size, sequence length)
        :return: ``(pos_score1, pos_score2, neg_score1, neg_score2)``, each
            (neg_size, batch_size)
        '''
        # shape of seqs (batch size, sequence length)
        seqs, seq_len, pos_rel1, pos_rel2, neg_rel1, neg_rel2 = batch

        # shape (batch_size, sequence length, dimension of embedding)
        inputs = self.word_embed.forward(seqs)
        outputs = self.question_encoder(inputs)

        # shape (batch_size, dimension of rel embedding)
        pos_rel1_embed = self.rel1_embed.word_lookup_table(pos_rel1)
        pos_rel2_embed = self.rel2_embed.word_lookup_table(pos_rel2)
        pos_rel1_embed = self.dropout(pos_rel1_embed)
        pos_rel2_embed = self.dropout(pos_rel2_embed)

        # shape (neg_size, batch_size, dimension of rel embedding)
        neg_rel1_embed = self.rel1_embed.word_lookup_table(neg_rel1)
        neg_rel2_embed = self.rel2_embed.word_lookup_table(neg_rel2)
        neg_rel1_embed = self.dropout(neg_rel1_embed)
        neg_rel2_embed = self.dropout(neg_rel2_embed)

        neg_size = neg_rel1_embed.size(0)

        # shape of `score` - (neg_size, batch_size)
        pos_score1 = self.cal_score(outputs, seq_len, pos_rel1_embed, neg_size)
        pos_score2 = self.cal_score(outputs, seq_len, pos_rel2_embed, neg_size)
        neg_score1 = self.cal_score(outputs, seq_len, neg_rel1_embed)
        neg_score2 = self.cal_score(outputs, seq_len, neg_rel2_embed)
        return pos_score1, pos_score2, neg_score1, neg_score2