def forward(self, layer1_s1, layer2_s1, l1, layer1_s2, layer2_s2, l2):
    # layer1_*/layer2_*: [B, T, D] encoder features; l1, l2: [B] valid lengths
    p_s1 = self.dropout_layer(layer1_s1)
    p_s2 = self.dropout_layer(layer1_s2)

    s1_layer1_out = torch_util.auto_rnn(self.lstm_1, p_s1, l1)
    s2_layer1_out = torch_util.auto_rnn(self.lstm_1, p_s2, l2)

    # BiDAF-style bidirectional attention between the two sequences.
    S = self.bidaf.similarity(s1_layer1_out, l1, s2_layer1_out, l2)
    s1_att, s2_att = self.bidaf.get_both_tile(S, s1_layer1_out, s2_layer1_out)

    # Heuristic matching features: [a; a_att; a - a_att; a * a_att]
    s1_coattentioned = torch.cat(
        [s1_layer1_out, s1_att,
         s1_layer1_out - s1_att,
         s1_layer1_out * s1_att], dim=2)
    s2_coattentioned = torch.cat(
        [s2_layer1_out, s2_att,
         s2_layer1_out - s2_att,
         s2_layer1_out * s2_att], dim=2)

    p_s1_coattentioned = F.relu(self.projection(s1_coattentioned))
    p_s2_coattentioned = F.relu(self.projection(s2_coattentioned))

    s1_coatt_features = torch.cat([p_s1_coattentioned, layer2_s1], dim=2)
    s2_coatt_features = torch.cat([p_s2_coattentioned, layer2_s2], dim=2)

    s1_coatt_features = self.dropout_layer(s1_coatt_features)
    s2_coatt_features = self.dropout_layer(s2_coatt_features)

    s1_layer2_out = torch_util.auto_rnn(self.lstm_2, s1_coatt_features, l1)
    s2_layer2_out = torch_util.auto_rnn(self.lstm_2, s2_coatt_features, l2)

    # Masked max pooling over time, then symmetric matching features.
    s1_lay2_maxout = torch_util.max_along_time(s1_layer2_out, l1)
    s2_lay2_maxout = torch_util.max_along_time(s2_layer2_out, l2)

    features = torch.cat(
        [s1_lay2_maxout, s2_lay2_maxout,
         torch.abs(s1_lay2_maxout - s2_lay2_maxout),
         s1_lay2_maxout * s2_lay2_maxout], dim=1)

    return self.classifier(features)
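# `self.bidaf.similarity` and `get_both_tile` above are project modules that are not
# shown in this file. The sketch below illustrates the bidirectional attention pattern
# they are assumed to implement: a length-masked similarity matrix from which each
# side attends over the other. The plain dot-product scoring is a simplification
# (BiDAF proper uses a trilinear scorer) and the helper name is illustrative only.
import torch
import torch.nn.functional as F


def _coattention_sketch(s1, l1, s2, l2):
    """s1: [B, T1, D], s2: [B, T2, D]; l1, l2: [B] valid lengths."""
    S = torch.bmm(s1, s2.transpose(1, 2))  # [B, T1, T2] similarity scores
    mask1 = torch.arange(s1.size(1), device=s1.device).unsqueeze(0) < l1.unsqueeze(1)  # [B, T1]
    mask2 = torch.arange(s2.size(1), device=s2.device).unsqueeze(0) < l2.unsqueeze(1)  # [B, T2]
    # For each s1 position, attend over the valid s2 positions.
    a12 = F.softmax(S.masked_fill(~mask2.unsqueeze(1), float('-inf')), dim=2)
    s1_att = torch.bmm(a12, s2)  # [B, T1, D]
    # For each s2 position, attend over the valid s1 positions.
    a21 = F.softmax(S.transpose(1, 2).masked_fill(~mask1.unsqueeze(1), float('-inf')), dim=2)
    s2_att = torch.bmm(a21, s1)  # [B, T2, D]
    return s1_att, s2_att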
def forward(self, input_ids, attention_mask, labels=None):
    # Legacy truncation logic from the two-sentence variant (the current signature
    # has no s1/s2/l1/l2), kept for reference:
    # if self.max_l:
    #     l1 = l1.clamp(max=self.max_l)
    #     l2 = l2.clamp(max=self.max_l)
    #     if s1.size(0) > self.max_l:
    #         s1 = s1[:self.max_l, :]
    #     if s2.size(0) > self.max_l:
    #         s2 = s2[:self.max_l, :]

    batch_l_1 = torch.sum(attention_mask, dim=1)  # [B] valid lengths

    # p_s1 = self.Embd(s1)
    embedding_1 = self.Embd(input_ids)

    s1_layer1_out = torch_util.auto_rnn(self.lstm, embedding_1, batch_l_1)
    # s2_layer1_out = torch_util.auto_rnn_bilstm(self.lstm, p_s2, l2)

    # Length truncate (legacy):
    # len1 = s1_layer1_out.size(0)
    # len2 = s2_layer1_out.size(0)
    # p_s1 = p_s1[:len1, :, :]
    # p_s2 = p_s2[:len2, :, :]

    # Highway-style shortcut: feed the embeddings alongside the previous layer's output.
    s1_layer2_in = torch.cat([embedding_1, s1_layer1_out], dim=2)
    # s2_layer2_in = torch.cat([p_s2, s2_layer1_out], dim=2)
    s1_layer2_out = torch_util.auto_rnn(self.lstm_1, s1_layer2_in, batch_l_1)
    # s2_layer2_out = torch_util.auto_rnn_bilstm(self.lstm_1, s2_layer2_in, l2)

    # Residual sum of the first two layers' outputs before the third LSTM.
    s1_layer3_in = torch.cat([embedding_1, s1_layer1_out + s1_layer2_out], dim=2)
    # s2_layer3_in = torch.cat([p_s2, s2_layer1_out + s2_layer2_out], dim=2)
    s1_layer3_out = torch_util.auto_rnn(self.lstm_2, s1_layer3_in, batch_l_1)
    # s2_layer3_out = torch_util.auto_rnn_bilstm(self.lstm_2, s2_layer3_in, l2)

    s1_layer3_maxout = torch_util.max_along_time(s1_layer3_out, batch_l_1)
    # s2_layer3_maxout = torch_util.max_along_time(s2_layer3_out, l2)

    # Only the last layer's pooled output is used; the pairwise matching features of
    # the two-sentence variant are kept below for reference:
    # features = torch.cat([s1_layer3_maxout, s2_layer3_maxout,
    #                       torch.abs(s1_layer3_maxout - s2_layer3_maxout),
    #                       s1_layer3_maxout * s2_layer3_maxout], dim=1)
    features = torch.cat([s1_layer3_maxout], dim=1)

    logits = self.classifier(features)

    loss = None
    if labels is not None:
        if self.num_labels == 1:
            # We are doing regression.
            loss_fct = MSELoss()
            loss = loss_fct(logits.view(-1), labels.view(-1))
        else:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    return (loss, logits)
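# `torch_util.auto_rnn` is a project helper used throughout these forward passes.
# A minimal sketch of the pack/run/unpack pattern it is assumed to wrap is given
# below; the name `_auto_rnn_sketch`, the batch-first layout, and the exact
# signature are assumptions for illustration, not the project's implementation.
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def _auto_rnn_sketch(rnn, seq, lengths):
    """rnn: an nn.LSTM/nn.GRU with batch_first=True; seq: [B, T, D]; lengths: [B] tensor."""
    packed = pack_padded_sequence(seq, lengths.cpu(), batch_first=True,
                                  enforce_sorted=False)
    packed_out, _ = rnn(packed)
    out, _ = pad_packed_sequence(packed_out, batch_first=True,
                                 total_length=seq.size(1))
    return out  # [B, T, H] (or [B, T, 2H] for a bidirectional RNN)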
def forward(self, layer1_s1, layer2_s1, l1, layer1_s2, layer2_s2, l2):
    # [B, T]
    p_s1 = self.dropout_layer(layer1_s1)
    p_s2 = self.dropout_layer(layer1_s2)

    s1_layer1_out = torch_util.auto_rnn(self.lstm_1, p_s1, l1)
    s2_layer1_out = torch_util.auto_rnn(self.lstm_1, p_s2, l2)

    s1_lay2_maxout = torch_util.max_along_time(s1_layer1_out, l1)
    s2_lay2_maxout = torch_util.max_along_time(s2_layer1_out, l2)

    features = torch.cat(
        [s1_lay2_maxout, s2_lay2_maxout,
         torch.abs(s1_lay2_maxout - s2_lay2_maxout),
         s1_lay2_maxout * s2_lay2_maxout], dim=1)

    return self.classifier(features)
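# `torch_util.max_along_time` is the masked pooling helper used by every model here.
# Below is a self-contained sketch of what it is assumed to do: a max over the time
# dimension restricted to each example's valid length. The name and signature are
# illustrative, not the project's actual implementation.
import torch


def _max_along_time_sketch(seq, lengths):
    """seq: [B, T, D]; lengths: [B] -> [B, D]."""
    T = seq.size(1)
    mask = torch.arange(T, device=seq.device).unsqueeze(0) < lengths.unsqueeze(1)  # [B, T]
    masked = seq.masked_fill(~mask.unsqueeze(-1), float('-inf'))  # padded steps never win
    return masked.max(dim=1).values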
def forward(self, input_ids, token_type_ids=None, attention_mask=None):
    # Precomputing the maximum context length matters because the same value must be
    # shared across GPUs; computing it dynamically per device is not feasible.
    encoded_layers, pooled_output = self.bert_encoder(
        input_ids, token_type_ids, attention_mask,
        output_all_encoded_layers=True)

    selected_output_layers = encoded_layers[-self.num_of_out_layers:]
    context_length = attention_mask.sum(dim=1)

    output_layer_list = []
    for output_layer in selected_output_layers:
        # [B, T, D] -> [B, D]
        output_layer_list.append(
            torch_util.max_along_time(output_layer, context_length))

    packed_output = torch.cat(output_layer_list, dim=1)
    return packed_output
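# A hedged usage sketch for the encoder above: its output is the concatenation of
# per-layer max-pooled vectors, i.e. [B, num_of_out_layers * hidden_size], which would
# typically feed a small classification head. Every size and the head itself below are
# assumptions for illustration, not part of the original model.
import torch
import torch.nn as nn

num_of_out_layers, hidden_size, num_labels = 2, 768, 3           # hypothetical sizes
packed_output = torch.randn(4, num_of_out_layers * hidden_size)  # stand-in for the encoder output
classifier_head = nn.Sequential(
    nn.Dropout(0.1),
    nn.Linear(num_of_out_layers * hidden_size, num_labels),
)
logits = classifier_head(packed_output)  # [4, num_labels]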
def forward(self, layer1_s1, layer2_s1, l1, layer1_s2, layer2_s2, l2,
            s1_span_obj, p_weights):
    # [B, T]
    p_s1 = self.dropout_layer(layer1_s1)
    p_s2 = self.dropout_layer(layer1_s2)

    s1_layer1_out = torch_util.auto_rnn(self.lstm_1, p_s1, l1)
    s2_layer1_out = torch_util.auto_rnn(self.lstm_1, p_s2, l2)

    # Sentence-wise alignment starts here: split the paragraph into sentences and
    # replicate the query once per sentence.
    s1_span_output, s1_span_output_l, s1_m_l, s1_m_span_obj = \
        span_tool.cut_paragraph_to_sentence(
            s1_layer1_out, l1, s1_span_obj, max_sentence_length=self.max_span_l)
    s1_m_layer1_output, s1_m_layer1_l = span_tool.merge_sentence_to_paragraph(
        s1_span_output, s1_span_output_l, s1_m_span_obj, s1_m_l)

    s2_span_output, s2_span_output_l, s2_m_l, s2_m_span_obj = \
        span_tool.replicate_query_for_span_align(s2_layer1_out, l2, s1_m_span_obj)
    s2_m_layer1_output, s2_m_layer1_l = span_tool.merge_sentence_to_paragraph(
        s2_span_output, s2_span_output_l, s2_m_span_obj, s2_m_l)

    # Alignment over the per-sentence spans.
    S = self.bidaf.similarity(s1_span_output, s1_span_output_l,
                              s2_span_output, s2_span_output_l)
    s1_att_span, s2_att_span = self.bidaf.get_both_tile(
        S, s1_span_output, s2_span_output)

    s1_att_output, _ = span_tool.merge_sentence_to_paragraph(
        s1_att_span, s1_span_output_l, s1_m_span_obj, s1_m_l)
    s2_att_output, _ = span_tool.merge_sentence_to_paragraph(
        s2_att_span, s2_span_output_l, s2_m_span_obj, s2_m_l)

    s1_coattentioned = torch.cat(
        [s1_m_layer1_output, s1_att_output,
         s1_m_layer1_output - s1_att_output,
         s1_m_layer1_output * s1_att_output], dim=2)
    s2_coattentioned = torch.cat(
        [s2_m_layer1_output, s2_att_output,
         s2_m_layer1_output - s2_att_output,
         s2_m_layer1_output * s2_att_output], dim=2)

    p_s1_coattentioned = F.relu(self.projection(s1_coattentioned))
    p_s2_coattentioned = F.relu(self.projection(s2_coattentioned))

    s1_coatt_features = torch.cat(
        [p_s1_coattentioned,
         span_tool.quick_truncate(layer2_s1, l1, s1_span_obj,
                                  self.max_span_l, mode='paragraph')[0]], dim=2)
    s2_coatt_features = torch.cat(
        [p_s2_coattentioned,
         span_tool.quick_truncate(layer2_s2, l2, s2_m_span_obj,
                                  self.max_span_l, mode='query')[0]], dim=2)

    s1_coatt_features = self.dropout_layer(s1_coatt_features)
    s2_coatt_features = self.dropout_layer(s2_coatt_features)

    s1_layer2_out = torch_util.auto_rnn(self.lstm_2, s1_coatt_features, s1_m_layer1_l)
    s2_layer2_out = torch_util.auto_rnn(self.lstm_2, s2_coatt_features, s2_m_layer1_l)

    # Span-weighted pooling: pool within each span, then scale each span's vector.
    s1_span_pooling_output, s1_span_pooling_l = \
        span_tool.weighted_max_pooling_over_span(
            s1_layer2_out, s1_m_layer1_l, s1_m_span_obj)
    s2_span_pooling_output, s2_span_pooling_l = \
        span_tool.weighted_max_pooling_over_span(
            s2_layer2_out, s2_m_layer1_l, s2_m_span_obj)

    weight_tensor, weight_l = span_tool.convert_input_weight_list_to_tensor(
        p_weights, s1_m_span_obj, s1_span_pooling_output.device)
    assert torch.equal(s1_span_pooling_l, weight_l)
    assert torch.equal(s2_span_pooling_l, weight_l)

    s1_span_pooling_output = s1_span_pooling_output * weight_tensor.unsqueeze(-1)
    s2_span_pooling_output = s2_span_pooling_output * weight_tensor.unsqueeze(-1)
    # Weighted pooling ends.

    s1_lay2_maxout = torch_util.max_along_time(s1_span_pooling_output, s1_span_pooling_l)
    s2_lay2_maxout = torch_util.max_along_time(s2_span_pooling_output, s2_span_pooling_l)

    features = torch.cat(
        [s1_lay2_maxout, s2_lay2_maxout,
         torch.abs(s1_lay2_maxout - s2_lay2_maxout),
         s1_lay2_maxout * s2_lay2_maxout], dim=1)

    return self.classifier(features)
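# The span-weighting step above scales each span's pooled vector by a scalar weight
# before the final max over spans. A toy, self-contained illustration of that idea;
# the shapes and weights are made up for demonstration, and the real weights come
# from `span_tool.convert_input_weight_list_to_tensor`.
import torch

span_vectors = torch.randn(2, 5, 16)                   # [B=2, num_spans=5, D=16]
span_weights = torch.rand(2, 5)                        # one scalar weight per span
weighted = span_vectors * span_weights.unsqueeze(-1)   # broadcast the weight over D
pooled = weighted.max(dim=1).values                    # [2, 16] max over weighted spans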
def span_maxpool(input_seq, span):
    # input_seq: [B, T, D]
    selected_seq, selected_length = span_util.span_select(input_seq, span)  # [B, T, D]
    maxout_r = torch_util.max_along_time(selected_seq, selected_length)
    return maxout_r
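# A small illustration of the operation `span_maxpool` is assumed to perform, using
# plain PyTorch instead of the project's `span_util` / `torch_util` helpers; the span
# format (inclusive start, exclusive end over the time axis) is an assumption here.
import torch

seq = torch.randn(1, 10, 8)                           # [B=1, T=10, D=8]
start, end = 2, 6                                     # hypothetical span boundaries
span_repr = seq[:, start:end, :].max(dim=1).values    # [1, 8] max over the span's timesteps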