Example #1
    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors,
            labels: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : `TextFieldTensors`
            From a `TextField`
        labels : `torch.IntTensor`, optional (default = `None`)
            From a `MultiLabelField`

        # Returns

        An output dictionary consisting of:

            - `logits` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                the unnormalized log probability of each label.
            - `probs` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                the probability of each label.
            - `loss` (`torch.FloatTensor`, optional) :
                A scalar loss to be optimised.
        """

        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        if self._feedforward is not None:
            embedded_text = self._feedforward(embedded_text)

        logits = self._classification_layer(embedded_text)
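        # multi-label setup: sigmoid yields an independent probability per
        # label, unlike softmax, which would force the labels to compete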
        probs = torch.sigmoid(logits)

        output_dict = {"logits": logits, "probs": probs}
        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)
        if labels is not None:
            loss = self._loss(logits,
                              labels.float().view(-1, self._num_labels))
            output_dict["loss"] = loss
            # TODO (John): This shouldn't be necessary as __call__ of the metrics detaches these
            # tensors anyways?
            cloned_logits, cloned_labels = logits.clone(), labels.clone()
            self._micro_f1(cloned_logits, cloned_labels)
            self._macro_f1(cloned_logits, cloned_labels)

        return output_dict
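
A minimal sketch of how the dictionary returned above might be consumed; `model`, `batch`, and the 0.5 threshold are assumptions for illustration, not part of the example:

    output = model(**batch)                        # runs the forward pass above
    # one independent yes/no decision per label, thresholding the sigmoid output
    predicted_labels = (output["probs"] > 0.5).long()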
Example #2

    def forward(self,
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                list_: Dict[str, torch.LongTensor],
                passages_length: torch.LongTensor = None,
                correct_passage: torch.LongTensor = None,
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # shape: B x N x T x E
        embedded_passage_list = self._embedder(list_)
        # unpack: batch size, number of passages, max passage length, embedding size
        (batch_size, num_passages, max_p, embedding_size) = embedded_passage_list.size()

        # shape: B x Tq x E
        embedded_question = self._embedder(question)
        # shape: B x (N*T) x E -- flatten the passages into one long sequence
        embedded_passage = embedded_passage_list.view(batch_size, -1, embedding_size)
        total_passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question)
        # mask computed per passage, then flattened to line up with embedded_passage
        passage_list_mask = util.get_text_field_mask(list_, num_wrapping_dims=1)
        passage_mask = passage_list_mask.view(batch_size, -1)

        # shape: B x T x 2H
        encoded_question = self._dropout(self._question_encoder(embedded_question, question_mask))
        encoded_passage = self._dropout(self._passage_encoder(embedded_passage, passage_mask))
        passage_mask = passage_mask.float()
        question_mask = question_mask.float()

        encoding_dim = encoded_question.size(-1)

        # initial hidden state for the gated match-GRU, shape: B x 2H
        gru_hidden = torch.zeros(batch_size, encoding_dim, device=encoded_passage.device)

        question_awared_passage = []
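        # gated attention-based recurrence: at each passage timestep, attend
        # over the question, gate the combined representation, and pass it
        # through a shared GRU cell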
        for timestep in range(total_passage_length):
            u_t_P = encoded_passage[:, timestep, :]
            # shape: B x Tq = attention(B x 2H, B x Tq x 2H)
            attn_weights = self._question_attention_for_passage(u_t_P, encoded_question, question_mask)
            # shape: B x 2H = weighted_sum(B x Tq x 2H, B x Tq)
            attended_question = util.weighted_sum(encoded_question, attn_weights)
            # shape: B x 4H
            passage_question_combined = torch.cat([u_t_P, attended_question], dim=-1)
            # shape: B x 4H
            gate = torch.sigmoid(self._gate(passage_question_combined))
            gru_input = gate * passage_question_combined
            # shape: B x 2H
            gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
            question_awared_passage.append(gru_hidden)

        # shape: B x T x 2H
        # question aware passage representation v_P
        question_awared_passage = torch.stack(question_awared_passage, dim=1)

        # compute question vector r_Q
        # shape: B x T = attention(B x 2H, B x T x 2H)
        v_r_Q_tiled = self._v_r_Q.unsqueeze(0).expand(batch_size, encoding_dim)
        attn_weights = self._question_attention_for_question(v_r_Q_tiled, encoded_question, question_mask)
        # shape: B x 2H
        r_Q = util.weighted_sum(encoded_question, attn_weights)
        # shape: B x T = attention(B x 2H, B x T x 2H)
        span_start_logits = self._passage_attention_for_answer(r_Q, question_awared_passage, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_start_log_probs = util.masked_log_softmax(span_start_logits, passage_mask)
        # shape: B x 2H
        c_t = util.weighted_sum(question_awared_passage, span_start_probs)
        # shape: B x 2H
        h_1 = self._dropout(self._answer_net(c_t, r_Q))

        span_end_logits = self._passage_attention_for_answer(h_1, question_awared_passage, passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_end_log_probs = util.masked_log_softmax(span_end_logits, passage_mask)

        # passage ranking: score each passage against the question vector r_Q
        g = []
        for i in range(num_passages):
            attn_weights = self._passage_attention_for_ranking(r_Q, question_awared_passage[:, i*max_p: (i + 1)*max_p, :], passage_mask[:, i*max_p: (i + 1)*max_p])
            r_P = util.weighted_sum(question_awared_passage[:, i*max_p: (i + 1)*max_p, :], attn_weights)
            question_passage_combined = torch.cat([r_Q, r_P], dim=-1)
            gi = self._dropout(self._match_layer_2(torch.tanh(self._match_layer_1(question_passage_combined))))
            g.append(gi)
        
        # running count of real (unmasked) tokens at each padded position;
        # used to convert between padded and unpadded token indices
        cumsum = torch.cumsum(passage_mask.long(), dim=1)

        g = torch.cat(g, dim=1)
        passage_log_probs = F.log_softmax(g, dim=-1)

        output_dict = {}
        if span_start is not None:
            # map gold (unpadded) span indices onto the padded, concatenated
            # passage: the k-th real token sits where cumsum first equals k + 1
            padded_span_start = span_start.clone()
            padded_span_end = span_end.clone()
            for b in range(batch_size):
                padded_span_start[b] = (cumsum[b] == span_start[b] + 1).nonzero()[0][0]
                padded_span_end[b] = (cumsum[b] == span_end[b] + 1).nonzero()[0][0]
            AP_loss = F.nll_loss(span_start_log_probs, padded_span_start.squeeze(-1)) +\
                F.nll_loss(span_end_log_probs, padded_span_end.squeeze(-1))
            PR_loss = F.nll_loss(passage_log_probs, correct_passage.squeeze(-1))
            loss = self._r * AP_loss + self._r * PR_loss
            output_dict['loss'] = loss

        _, max_start = torch.max(span_start_probs, dim=1)
        _, max_end = torch.max(span_end_probs, dim=1)
        # unpad: map predicted positions in the padded sequence back to
        # indices over the real tokens only
        for b in range(batch_size):
            max_start[b] = cumsum[b, max_start[b]] - 1
            max_end[b] = cumsum[b, max_end[b]] - 1
        output_dict['span_start_idx'] = max_start
        output_dict['span_end_idx'] = max_end

        self._num_iter += 1
        if span_start is not None and self._num_iter % 50 == 0:
            print(" gold %i:%i|predicted %i:%i" % (span_start.squeeze(-1)[0], span_end.squeeze(-1)[0],
                                                   max_start[0].item(), max_end[0].item()))

        return output_dict
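
The pad/unpad bookkeeping above hinges on a cumulative-sum trick over the token mask. A standalone sketch of the idea (the tensors here are toy values, not taken from the model):

    import torch

    # 1 marks a real token, 0 marks padding in the concatenated passages
    mask = torch.tensor([[1, 1, 0, 0, 1, 1, 1, 0]])
    cumsum = torch.cumsum(mask, dim=1)  # [[1, 2, 2, 2, 3, 4, 5, 5]]

    # unpadded index 2 (the third real token) lives at the first padded
    # position where the running count reaches 3
    padded_idx = (cumsum[0] == 2 + 1).nonzero()[0][0]
    print(padded_idx)                 # tensor(4)

    # reverse direction: count the real tokens up to a padded position, minus one
    print(cumsum[0, padded_idx] - 1)  # tensor(2)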