Beispiel #1
0
def lens_to_mask(lens: torch.IntTensor) -> torch.Tensor:
    """
    Create a 2-D mask tensor of shape (batch_size, max_length) and dtype float32
    from a 1-D tensor of integers describing the length of batch samples in another tensor.

    Row ``i`` holds ``1.0`` in its first ``lens[i]`` positions and ``0.0`` elsewhere.
    The mask is created on the same device as ``lens``.

    An empty ``lens`` yields an empty ``(0, 0)`` mask instead of raising.
    Lengths are expected to be non-negative.
    """
    if lens.numel() == 0:
        # ``max()`` of an empty tensor is undefined; return an empty mask.
        return lens.new_zeros((0, 0), dtype=torch.float32)
    max_len = int(lens.max())
    # Compare each column index against each row's length in one broadcasted
    # operation instead of filling the rows one by one in Python.
    positions = torch.arange(max_len, device=lens.device)
    return (positions.unsqueeze(0) < lens.unsqueeze(1)).to(torch.float32)
Beispiel #2
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                sentence_spans: torch.IntTensor = None,
                sent_labels: torch.IntTensor = None,
                evd_chain_labels: torch.IntTensor = None,
                q_type: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Score every passage sentence with the span gate and compute the gate loss.

        Parameters
        ----------
        question : text-field tensors, embedded with ``num_wrapping_dims=0``.
        passage : text-field tensors, embedded with ``num_wrapping_dims=1``.
            Must contain a ``'bert'`` entry (its shape is printed).
        span_start, span_end, sentence_spans, q_type : accepted but not read
            by this method.
        sent_labels : per-sentence labels of shape ``(batch_size, num_spans)``;
            negative entries are treated as padding. Despite the ``None``
            default it is dereferenced unconditionally when the loss is computed.
        evd_chain_labels : only read when ``self._sent_labels_src == 'chain'``,
            in which case its first chain replaces ``sent_labels`` as supervision.
        metadata : optional per-instance dicts. When given, tokens, labels, ids,
            evidence chains and answer-sentence indices are copied to the output
            and the evidence / F1 metrics are updated.

        Returns
        -------
        A dict with ``"pred_sent_labels"`` and ``"gate_probs"`` (both viewed as
        ``(batch_size, num_spans)``), ``"loss"`` when loss tracking succeeds,
        ``"evd_self_attention_score"`` when attention output is enabled and a
        score is available, plus the metadata-derived lists described above.
        """
        if self._sent_labels_src == 'chain':
            batch_size, num_spans = sent_labels.size()
            sent_labels_mask = (sent_labels >= 0).float()
            print("chain:", evd_chain_labels)
            # The chain is used as the label that supervises the gate. Only the
            # first chain in ``evd_chain_labels`` is taken; right now the number
            # of chains should only be one too.
            evd_chain_labels = evd_chain_labels[:, 0].long()
            # Build the gate labels. The dim is 1 + num_spans to account for the
            # end embedding, which occupies column 0.
            # shape: (batch_size, 1+num_spans)
            sent_labels = sent_labels.new_zeros((batch_size, 1+num_spans))
            sent_labels.scatter_(1, evd_chain_labels, 1.)
            # Remove the column for the end embedding.
            # shape: (batch_size, num_spans)
            sent_labels = sent_labels[:, 1:].float()
            # Put -1 back at padded positions so the loss ignores them.
            sent_labels = sent_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)

        print('\nBert wordpiece size:', passage['bert'].shape)
        # NOTE(review): ``embedded_question`` and ``ques_mask`` are not read
        # further down; the calls are kept so the computation performed by
        # this method is unchanged. Confirm whether they can be dropped.
        # shape: [batch_size, max_q_len, emb_size]
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=0)
        # shape: [batch_size, num_sent, max_sent_len+q_len, embedding_dim]
        embedded_passage = self._text_field_embedder(passage, num_wrapping_dims=1)
        # mask
        ques_mask = util.get_text_field_mask(question, num_wrapping_dims=0).float()
        context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()

        # gate prediction
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        # Shape(gate_mask): (batch_size, num_spans)
        gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(
            embedded_passage, context_mask,
            self._gate_self_attention_layer,
            self._gate_sent_encoder)
        batch_size, num_spans, _ = context_mask.size()

        # Labels equal to -1 (padding) are skipped via ``ignore_index``.
        loss = F.nll_loss(F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
                          sent_labels.long().view(batch_size * num_spans), ignore_index=-1)

        # Binarize the gate with a fixed 0.3 threshold.
        gate = (gate >= 0.3).long()
        gate = gate.view(batch_size, num_spans)

        output_dict = {
            "pred_sent_labels": gate,  # [B, num_span]
            "gate_probs": pred_sent_probs[:, 1].view(batch_size, num_spans),  # [B, num_span]
        }
        if self._output_att_scores and g_att_score is not None:
            output_dict['evd_self_attention_score'] = g_att_score

        # Record the loss for training.
        try:
            self._loss_trackers['loss'](loss)
            output_dict["loss"] = loss
        except RuntimeError:
            # Best-effort diagnostics. The previous handler printed
            # ``span_start_logits``, which does not exist in this method and
            # turned every RuntimeError into a NameError.
            print('\n meta_data:', metadata)
            print(gate_logit.shape)

        # Print the 1-based indices of the sentences labelled 1, per instance.
        print("sent label:")
        for b_label in np.array(sent_labels.cpu()):
            b_label = b_label == 1
            indices = np.arange(len(b_label))
            print(indices[b_label] + 1)

        # Update the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_sent_tokens'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                # Shift sentence indices back by one; index 0 is dropped from the chain.
                evd_possible_chains.append([s_idx-1 for s_idx in metadata[i]['evd_possible_chains'][0] if s_idx > 0])
                ans_sent_idxs.append([s_idx-1 for s_idx in metadata[i]['ans_sent_idxs']])
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_gate = gate[i].detach().cpu().numpy()
                    # Count a hit when at least one answer sentence passes the gate.
                    if any(pred_sent_gate[s_idx-1] > 0 for s_idx in metadata[i]['ans_sent_idxs']):
                        self.evd_ans_metric(1)
                    else:
                        self.evd_ans_metric(0)
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1), gate_mask.view(-1))
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_sent_tokens'] = passage_tokens
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['_id'] = ids

        return output_dict
Beispiel #3
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            evd_chain_labels: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Predict an answer span, an answer type and a per-sentence evidence gate.

        Two parallel encoder stacks are run over the same embeddings: one for
        answer prediction and one (the ``*_sp`` layers) for gate prediction.
        The binarized gate zeroes out the span representations of rejected
        sentences before they are added back onto the answer-side passage
        encoding.

        Parameters
        ----------
        question, passage : text-field tensors fed to ``self._text_field_embedder``.
        span_start, span_end, q_type : gold targets. The combined loss is only
            computed when ``span_start`` is not ``None``.
        sentence_spans : sentence boundaries passed to ``convert_sequence_to_spans``.
        sent_labels : per-sentence labels of shape ``(batch_size, num_spans)``;
            negative entries are treated as padding. Despite the ``None``
            default it is dereferenced unconditionally for the gate loss.
        evd_chain_labels : only read when ``self._sent_labels_src == 'chain'``,
            in which case its first chain replaces ``sent_labels`` as supervision.
        metadata : optional per-instance dicts. When given, the predicted answer
            strings are built, metrics are updated and tokens, labels and ids
            are copied to the output.

        Returns
        -------
        A dict with ``"span_start_logits"``, ``"span_end_logits"``,
        ``"best_span"``, ``"pred_sent_labels"`` and ``"gate_probs"``; ``"loss"``
        when it could be computed; attention scores when
        ``self._output_att_scores`` is set; plus the metadata-derived entries.
        """
        if self._sent_labels_src == 'chain':
            batch_size, num_spans = sent_labels.size()
            sent_labels_mask = (sent_labels >= 0).float()
            print("chain:", evd_chain_labels)
            # The chain is used as the label that supervises the gate. Only the
            # first chain in ``evd_chain_labels`` is taken; right now the number
            # of chains should only be one too.
            evd_chain_labels = evd_chain_labels[:, 0].long()
            # Build the gate labels. The dim is 1 + num_spans to account for the
            # end embedding, which occupies column 0.
            # shape: (batch_size, 1+num_spans)
            sent_labels = sent_labels.new_zeros((batch_size, 1 + num_spans))
            sent_labels.scatter_(1, evd_chain_labels, 1.)
            # Remove the column for the end embedding.
            # shape: (batch_size, num_spans)
            sent_labels = sent_labels[:, 1:].float()
            # Put -1 back at padded positions so the loss ignores them.
            sent_labels = sent_labels * sent_labels_mask + -1. * (
                1 - sent_labels_mask)

        # word + char embedding
        embedded_question = self._text_field_embedder(question)
        embedded_passage = self._text_field_embedder(passage)
        # mask
        ques_mask = util.get_text_field_mask(question).float()
        context_mask = util.get_text_field_mask(passage).float()

        # BiDAF for answer prediction
        ques_output = self._dropout(
            self._phrase_layer(embedded_question, ques_mask))
        context_output = self._dropout(
            self._phrase_layer(embedded_passage, context_mask))

        modeled_passage, _, qc_score = self.qc_att(context_output, ques_output,
                                                   ques_mask)

        modeled_passage = self._modeling_layer(modeled_passage, context_mask)

        # BiDAF for gate prediction
        ques_output_sp = self._dropout(
            self._phrase_layer_sp(embedded_question, ques_mask))
        context_output_sp = self._dropout(
            self._phrase_layer_sp(embedded_passage, context_mask))

        modeled_passage_sp, _, qc_score_sp = self.qc_att_sp(
            context_output_sp, ques_output_sp, ques_mask)

        modeled_passage_sp = self._modeling_layer_sp(modeled_passage_sp,
                                                     context_mask)

        # gate prediction
        # Shape(spans_rep): (batch_size * num_spans, max_batch_span_width, embedding_dim)
        # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
        spans_rep_sp, spans_mask = convert_sequence_to_spans(
            modeled_passage_sp, sentence_spans)
        spans_rep, _ = convert_sequence_to_spans(modeled_passage,
                                                 sentence_spans)
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        # Shape(gate_mask): (batch_size, num_spans)
        gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(
            spans_rep_sp, spans_mask, self._gate_self_attention_layer,
            self._gate_sent_encoder)
        batch_size, num_spans, _ = spans_mask.size()

        # Labels equal to -1 (padding) are skipped via ``ignore_index``.
        strong_sup_loss = F.nll_loss(
            F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
            sent_labels.long().view(batch_size * num_spans),
            ignore_index=-1)

        # Binarize the gate with a fixed 0.3 threshold, then zero out the
        # answer-side representations of sentences the gate rejected.
        gate = (gate >= 0.3).long()
        spans_rep = spans_rep * gate.unsqueeze(-1).float()
        attended_sent_embeddings = convert_span_to_sequence(
            modeled_passage_sp, spans_rep, spans_mask)

        modeled_passage = attended_sent_embeddings + modeled_passage

        # The self-attention layer returns a tuple: element 0 is added back as
        # a residual, element 2 is kept as the attention score.
        self_att_passage = self._self_attention_layer(modeled_passage,
                                                      mask=context_mask)
        modeled_passage = modeled_passage + self_att_passage[0]
        self_att_score = self_att_passage[2]

        # Masked-out positions get a large negative logit.
        output_start = self._span_start_encoder(modeled_passage, context_mask)
        span_start_logits = self.linear_start(output_start).squeeze(
            2) - 1e30 * (1 - context_mask)
        output_end = torch.cat([modeled_passage, output_start], dim=2)
        output_end = self._span_end_encoder(output_end, context_mask)
        span_end_logits = self.linear_end(output_end).squeeze(
            2) - 1e30 * (1 - context_mask)

        # Answer type is predicted from a max-pool over dim 1 of the
        # concatenated encodings.
        output_type = torch.cat([modeled_passage, output_end, output_start],
                                dim=2)
        output_type = torch.max(output_type, 1)[0]
        predict_type = self.linear_type(output_type)
        type_predicts = torch.argmax(predict_type, 1)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span": best_span,
            "pred_sent_labels": gate.view(batch_size,
                                          num_spans),  #[B, num_span]
            "gate_probs":
            pred_sent_probs[:, 1].view(batch_size, num_spans),  #[B, num_span]
        }
        if self._output_att_scores:
            if qc_score is not None:
                output_dict['qc_score'] = qc_score
            if qc_score_sp is not None:
                output_dict['qc_score_sp'] = qc_score_sp
            if self_att_score is not None:
                output_dict['self_attention_score'] = self_att_score
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # Print the 1-based indices of the sentences labelled 1, per instance.
        print("sent label:")
        for b_label in np.array(sent_labels.cpu()):
            b_label = b_label == 1
            indices = np.arange(len(b_label))
            print(indices[b_label] + 1)
        # Compute the loss for training.
        if span_start is not None:
            try:
                start_loss = nll_loss(
                    util.masked_log_softmax(span_start_logits, None),
                    span_start.squeeze(-1))
                end_loss = nll_loss(
                    util.masked_log_softmax(span_end_logits, None),
                    span_end.squeeze(-1))
                type_loss = nll_loss(
                    util.masked_log_softmax(predict_type, None), q_type)
                loss = start_loss + end_loss + type_loss + strong_sup_loss
                self._loss_trackers['loss'](loss)
                self._loss_trackers['start_loss'](start_loss)
                self._loss_trackers['end_loss'](end_loss)
                self._loss_trackers['type_loss'](type_loss)
                self._loss_trackers['strong_sup_loss'](strong_sup_loss)
                output_dict["loss"] = loss
            except RuntimeError:
                # Best-effort diagnostics: on failure the batch is reported and
                # "loss" is left out of the output.
                print('\n meta_data:', metadata)
                print(span_start_logits.shape)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            token_spans_sp = []
            token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                token_spans_sp.append(metadata[i]['token_spans_sp'])
                token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                # Predicted type 1 answers 'yes', type 2 answers 'no'; any other
                # type extracts the best span's text from the original passage.
                if type_predicts[i] == 1:
                    best_span_string = 'yes'
                elif type_predicts[i] == 2:
                    best_span_string = 'no'
                else:
                    predicted_span = tuple(best_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    best_span_string = passage_str[start_offset:end_offset]

                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                if answer_texts:
                    self._squad_metrics(best_span_string.lower(), answer_texts)

                # Shift sentence indices back by one; index 0 is dropped from the chain.
                evd_possible_chains.append([
                    s_idx - 1
                    for s_idx in metadata[i]['evd_possible_chains'][0]
                    if s_idx > 0
                ])
                ans_sent_idxs.append(
                    [s_idx - 1 for s_idx in metadata[i]['ans_sent_idxs']])
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1),
                             gate_mask.view(-1))
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['token_spans_sp'] = token_spans_sp
            output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['_id'] = ids

        return output_dict