def forward(self, layer1_s1, layer2_s1, l1, layer1_s2, layer2_s2,
            l2):  # [B, T]
    """Match two sequences with co-attention and classify the pair.

    Both sequences are encoded by the shared ``lstm_1``, cross-attended via
    ``self.bidaf``, fused, re-encoded by ``lstm_2``, max-pooled over time,
    and finally combined into a symmetric matching vector for the classifier.

    Args:
        layer1_s1, layer1_s2: first-layer representations of each sequence.
        layer2_s1, layer2_s2: second-layer representations, concatenated in
            before the second RNN.
        l1, l2: per-example valid lengths for each sequence.

    Returns:
        Classifier output over the paired features.
    """
    s1_dropped = self.dropout_layer(layer1_s1)
    s2_dropped = self.dropout_layer(layer1_s2)

    # Shared encoder over both sequences.
    s1_enc = torch_util.auto_rnn(self.lstm_1, s1_dropped, l1)
    s2_enc = torch_util.auto_rnn(self.lstm_1, s2_dropped, l2)

    # Cross-attention: each side attends over the other.
    sim = self.bidaf.similarity(s1_enc, l1, s2_enc, l2)
    s1_att, s2_att = self.bidaf.get_both_tile(sim, s1_enc, s2_enc)

    def _fuse(enc, att):
        # Standard matching heuristics: [a; b; a - b; a * b].
        return torch.cat([enc, att, enc - att, enc * att], dim=2)

    s1_projected = F.relu(self.projection(_fuse(s1_enc, s1_att)))
    s2_projected = F.relu(self.projection(_fuse(s2_enc, s2_att)))

    # Append the layer-2 inputs, regularize, and re-encode.
    s1_features = self.dropout_layer(torch.cat([s1_projected, layer2_s1], dim=2))
    s2_features = self.dropout_layer(torch.cat([s2_projected, layer2_s2], dim=2))

    s1_encoded = torch_util.auto_rnn(self.lstm_2, s1_features, l1)
    s2_encoded = torch_util.auto_rnn(self.lstm_2, s2_features, l2)

    # Max over valid timesteps -> one vector per sequence.
    v1 = torch_util.max_along_time(s1_encoded, l1)
    v2 = torch_util.max_along_time(s2_encoded, l2)

    # Symmetric pair representation: [v1; v2; |v1 - v2|; v1 * v2].
    pair_features = torch.cat(
        [v1, v2, torch.abs(v1 - v2), v1 * v2], dim=1)

    return self.classifier(pair_features)
# Esempio n. 2
# 0
    def forward(self, input_ids, attention_mask, labels=None):
        """Encode a single sequence with a 3-layer residual LSTM stack and classify it.

        Args:
            input_ids: token id tensor, [B, T].
            attention_mask: 1 for real tokens, 0 for padding, [B, T].
            labels: optional gold labels; when given, a loss is computed
                (MSE when ``self.num_labels == 1``, cross-entropy otherwise).

        Returns:
            ``(loss, logits)`` — ``loss`` is ``None`` when ``labels`` is ``None``.
        """
        # Number of non-padding tokens per example; used both to pack the
        # RNN inputs and to mask the temporal max-pooling.
        batch_l_1 = torch.sum(attention_mask, dim=1)

        embedding_1 = self.Embd(input_ids)

        s1_layer1_out = torch_util.auto_rnn(self.lstm, embedding_1, batch_l_1)

        # Highway-style input: raw embeddings concatenated with layer-1 states.
        s1_layer2_in = torch.cat([embedding_1, s1_layer1_out], dim=2)
        s1_layer2_out = torch_util.auto_rnn(self.lstm_1, s1_layer2_in, batch_l_1)

        # Residual connection between the layer-1 and layer-2 outputs.
        s1_layer3_in = torch.cat([embedding_1, s1_layer1_out + s1_layer2_out], dim=2)
        s1_layer3_out = torch_util.auto_rnn(self.lstm_2, s1_layer3_in, batch_l_1)

        # [B, T, D] -> [B, D]: max over valid timesteps only. The original
        # wrapped this in a no-op torch.cat([x], dim=1); use it directly.
        features = torch_util.max_along_time(s1_layer3_out, batch_l_1)

        logits = self.classifier(features)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # A single output unit means we are doing regression.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return (loss, logits)
    def forward(self, layer1_s1, layer2_s1, l1, layer1_s2, layer2_s2,
                l2):  # [B, T]
        """Siamese encoding of two sequences followed by pair classification.

        Each sequence is dropout-regularized, encoded by the shared
        ``lstm_1``, max-pooled over its valid timesteps, and the two pooled
        vectors are combined into a symmetric matching representation.
        ``layer2_s1``/``layer2_s2`` are accepted for interface compatibility
        but unused here.

        Returns:
            Classifier output over the paired features.
        """
        s1_encoded = torch_util.auto_rnn(
            self.lstm_1, self.dropout_layer(layer1_s1), l1)
        s2_encoded = torch_util.auto_rnn(
            self.lstm_1, self.dropout_layer(layer1_s2), l2)

        # [B, T, D] -> [B, D], masked by the per-example lengths.
        v1 = torch_util.max_along_time(s1_encoded, l1)
        v2 = torch_util.max_along_time(s2_encoded, l2)

        # Symmetric pair representation: [v1; v2; |v1 - v2|; v1 * v2].
        pair_features = torch.cat(
            [v1, v2, torch.abs(v1 - v2), v1 * v2], dim=1)

        return self.classifier(pair_features)
# Esempio n. 4
# 0
    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Max-pool the last ``num_of_out_layers`` BERT layers into one vector.

        Args:
            input_ids: token id tensor, [B, T].
            token_type_ids: optional segment ids, [B, T].
            attention_mask: 1 for real tokens, 0 for padding, [B, T].

        Returns:
            [B, num_of_out_layers * D] concatenation of the per-layer
            masked max-pooled representations.
        """
        # Precomputing of the max_context_length is important
        # because we want the same value to be shared to different GPUs,
        # dynamic calculating is not feasible.
        encoded_layers, _ = self.bert_encoder(
            input_ids,
            token_type_ids,
            attention_mask,
            output_all_encoded_layers=True)
        selected_output_layers = encoded_layers[-self.num_of_out_layers:]
        context_length = attention_mask.sum(dim=1)

        # [B, T, D] -> [B, D] per selected layer, masked by context_length.
        # (Comprehension replaces the manual append loop; the enumerate
        # index in the original was unused.)
        output_layer_list = [
            torch_util.max_along_time(output_layer, context_length)
            for output_layer in selected_output_layers
        ]

        return torch.cat(output_layer_list, dim=1)
    def forward(self, layer1_s1, layer2_s1, l1, layer1_s2, layer2_s2, l2,
                s1_span_obj, p_weights):  # [B, T]
        """Sentence-span-aligned co-attention matching of two sequences.

        s1 (a paragraph) is cut into sentence spans, each span is aligned
        against a replicated copy of s2 (the query) with BiDAF attention,
        the spans are merged back into paragraph form, re-encoded, weighted
        per-span by ``p_weights``, max-pooled, and classified.

        NOTE(review): the span bookkeeping below keeps parallel state
        (tensor, lengths ``*_l``, span objects ``*_span_obj``) that must
        stay in sync across every span_tool call — treat statement order
        as load-bearing.
        """
        p_s1 = self.dropout_layer(layer1_s1)
        p_s2 = self.dropout_layer(layer1_s2)

        # Shared first-layer encoder over both sequences.
        s1_layer1_out = torch_util.auto_rnn(self.lstm_1, p_s1, l1)
        s2_layer1_out = torch_util.auto_rnn(self.lstm_1, p_s2, l2)

        # sentence wise alignment start:
        # Cut the s1 paragraph into per-sentence spans (truncated to
        # self.max_span_l) so attention is computed sentence-by-sentence.
        s1_span_output, s1_span_output_l, s1_m_l, s1_m_span_obj = span_tool.cut_paragraph_to_sentence(
            s1_layer1_out,
            l1,
            s1_span_obj,
            max_sentence_length=self.max_span_l)

        # Merged (paragraph-shaped) view of the truncated s1 spans.
        s1_m_layer1_output, s1_m_layer1_l = span_tool.merge_sentence_to_paragraph(
            s1_span_output, s1_span_output_l, s1_m_span_obj, s1_m_l)

        # Replicate the s2 query once per s1 sentence span so each span
        # can attend over its own copy.
        s2_span_output, s2_span_output_l, s2_m_l, s2_m_span_obj = span_tool.replicate_query_for_span_align(
            s2_layer1_out, l2, s1_m_span_obj)

        s2_m_layer1_output, s2_m_layer1_l = span_tool.merge_sentence_to_paragraph(
            s2_span_output, s2_span_output_l, s2_m_span_obj, s2_m_l)

        # Alignment
        # BiDAF attention between each s1 sentence span and its query copy.
        S = self.bidaf.similarity(s1_span_output, s1_span_output_l,
                                  s2_span_output, s2_span_output_l)
        s1_att_span, s2_att_span = self.bidaf.get_both_tile(
            S, s1_span_output, s2_span_output)

        # Merge the attended spans back into paragraph shape (lengths are
        # already known, so the returned lengths are discarded).
        s1_att_output, _ = span_tool.merge_sentence_to_paragraph(
            s1_att_span, s1_span_output_l, s1_m_span_obj, s1_m_l)
        s2_att_output, _ = span_tool.merge_sentence_to_paragraph(
            s2_att_span, s2_span_output_l, s2_m_span_obj, s2_m_l)

        # Matching heuristics per side: [a; b; a - b; a * b].
        s1_coattentioned = torch.cat([
            s1_m_layer1_output, s1_att_output, s1_m_layer1_output -
            s1_att_output, s1_m_layer1_output * s1_att_output
        ],
                                     dim=2)

        s2_coattentioned = torch.cat([
            s2_m_layer1_output, s2_att_output, s2_m_layer1_output -
            s2_att_output, s2_m_layer1_output * s2_att_output
        ],
                                     dim=2)

        p_s1_coattentioned = F.relu(self.projection(s1_coattentioned))
        p_s2_coattentioned = F.relu(self.projection(s2_coattentioned))

        # Re-attach the layer-2 inputs, truncated the same way the layer-1
        # spans were. NOTE(review): s1 truncates with the raw s1_span_obj
        # ('paragraph' mode) while s2 uses the merged s2_m_span_obj
        # ('query' mode) — presumably intentional, verify against span_tool.
        s1_coatt_features = torch.cat([
            p_s1_coattentioned,
            span_tool.quick_truncate(
                layer2_s1, l1, s1_span_obj, self.max_span_l,
                mode='paragraph')[0]
        ],
                                      dim=2)
        s2_coatt_features = torch.cat([
            p_s2_coattentioned,
            span_tool.quick_truncate(
                layer2_s2, l2, s2_m_span_obj, self.max_span_l, mode='query')[0]
        ],
                                      dim=2)

        s1_coatt_features = self.dropout_layer(s1_coatt_features)
        s2_coatt_features = self.dropout_layer(s2_coatt_features)

        # Second encoder over the fused paragraph-shaped features.
        s1_layer2_out = torch_util.auto_rnn(self.lstm_2, s1_coatt_features,
                                            s1_m_layer1_l)
        s2_layer2_out = torch_util.auto_rnn(self.lstm_2, s2_coatt_features,
                                            s2_m_layer1_l)

        # Span weighted pooling
        # Max-pool within each span, then scale every span vector by its
        # externally supplied weight (p_weights).
        s1_span_pooling_output, s1_span_pooling_l = span_tool.weighted_max_pooling_over_span(
            s1_layer2_out, s1_m_layer1_l, s1_m_span_obj)
        s2_span_pooling_output, s2_span_pooling_l = span_tool.weighted_max_pooling_over_span(
            s2_layer2_out, s2_m_layer1_l, s2_m_span_obj)
        weight_tensor, weight_l = span_tool.convert_input_weight_list_to_tensor(
            p_weights, s1_m_span_obj, s1_span_pooling_output.device)
        # Both sides must expose exactly one pooled vector per weight.
        assert torch.equal(s1_span_pooling_l, weight_l)
        assert torch.equal(s2_span_pooling_l, weight_l)

        s1_span_pooling_output = s1_span_pooling_output * weight_tensor.unsqueeze(
            -1)
        s2_span_pooling_output = s2_span_pooling_output * weight_tensor.unsqueeze(
            -1)
        # weight pooling ends

        # Final max over the (weighted) span axis -> one vector per example.
        s1_lay2_maxout = torch_util.max_along_time(s1_span_pooling_output,
                                                   s1_span_pooling_l)
        s2_lay2_maxout = torch_util.max_along_time(s2_span_pooling_output,
                                                   s2_span_pooling_l)

        # Symmetric pair representation for classification.
        features = torch.cat([
            s1_lay2_maxout, s2_lay2_maxout,
            torch.abs(s1_lay2_maxout - s2_lay2_maxout),
            s1_lay2_maxout * s2_lay2_maxout
        ],
                             dim=1)

        return self.classifier(features)
# Esempio n. 6
# 0
 def span_maxpool(input_seq, span):  # [B, T, D]
     """Select the tokens inside ``span`` and max-pool them over time.

     Args:
         input_seq: sequence tensor, [B, T, D].
         span: span specification understood by ``span_util.span_select``.

     Returns:
         Max over the selected (valid-length-masked) timesteps.
     """
     selected, lengths = span_util.span_select(input_seq, span)  # [B, T, D]
     return torch_util.max_along_time(selected, lengths)