Example #1
    def forward(self,  # type: ignore
                task_index: torch.IntTensor,
                reverse: torch.ByteTensor,
                epoch_trained: torch.IntTensor,
                for_training: torch.ByteTensor,
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                text_id: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        embeddeds = self._encoder(task_index, tokens, epoch_trained, self._valid_discriminator, reverse, for_training,
                                  text_id)
        batch_size = get_batch_size(embeddeds["embedded_text"])

        sentiment_logits = self._sentiment_discriminator(embeddeds["embedded_text"])

        p_domain_logits = self._p_domain_discriminator(embeddeds["private_embedding"])

        # TODO set reverse = true
        s_domain_logits = self._s_domain_discriminator(embeddeds["share_embedding"], reverse=reverse)
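        # the reverse flag presumably drives a gradient-reversal layer in the shared-domain
        # discriminator, pushing the shared encoder toward domain-invariant features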

        logits = [sentiment_logits, p_domain_logits, s_domain_logits]

        # domain_logits = self._domain_discriminator(embedded_text)
        output_dict = {'logits': sentiment_logits}
        if label is not None:
            loss = self._loss(sentiment_logits, label)
            # task_index = task_index.unsqueeze(0)
            task_index = task_index.expand(batch_size)
            # targets = [label, label, label, task_index, task_index]
            # print(p_domain_logits.shape, task_index, task_index.shape)
            p_domain_loss = self._domain_loss(p_domain_logits, task_index)
            s_domain_loss = self._domain_loss(s_domain_logits, task_index)
            # logger.info("Share domain logits standard variation is {}",
            #             torch.mean(torch.std(F.softmax(s_domain_logits), dim=-1)))
            output_dict["tokens"] = tokens
            output_dict['stm_loss'] = loss
            output_dict['p_d_loss'] = p_domain_loss
            output_dict['s_d_loss'] = s_domain_loss
            # TODO add share domain logits std loss
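            # the 0.06 / 0.04 factors weight the adversarial domain losses against the
            # sentiment loss; presumably hand-tuned hyperparameters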
            output_dict['loss'] = loss + 0.06 * p_domain_loss + 0.04 * s_domain_loss

            for metric_name, metric in self.metrics.items():
                if "auc" in metric_name:
                    metric(self.decode(output_dict)["label"], label)
                else:
                    metric(sentiment_logits, label)
        if not for_training:
            with open("class_probabilities.txt", "a", encoding="utf8") as f:
                f.write(f"Task: {TASKS_NAME[task_index[0].item()]}\nLine ID: ")
                f.write(" ".join(list(map(str, text_id.cpu().detach().numpy()))))
                f.write("\nProb: ")
                f.write(" ".join(list(map(str, F.softmax(sentiment_logits, dim=-1).cpu().detach().numpy()))))
                f.write("\nLabel: " + " ".join(list(map(str, label.cpu().detach().numpy()))) + "\n")
                f.write("\n\n\n")
        return output_dict
Example #2
    def forward(self, x: torch.FloatTensor, input_lengths: torch.IntTensor):
        # conv -> relu -> dropout
        for conv in self.convolutions:
            x = conv(x)
            x = F.relu(x)
            x = F.dropout(x, 0.5, self.training)
        # transpose (batch, channels, time) -> (batch, time, channels) for the batch-first LSTM
        x = x.transpose(1, 2)

        input_lengths = input_lengths.cpu().numpy()
        # pack the padded sequence for lstm
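        # note: pack_padded_sequence expects lengths sorted in decreasing order
        # unless enforce_sorted=False is passed (available since PyTorch 1.1)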
        x = nn.utils.rnn.pack_padded_sequence(x,
                                              input_lengths,
                                              batch_first=True)
        # lstm
        self.lstm.flatten_parameters()
        out, _ = self.lstm(x)
        # recover from packed sequence
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        return out
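
A minimal, self-contained sketch of the conv -> pack -> LSTM pattern above, with hypothetical layer sizes (the real module builds self.convolutions and self.lstm in its constructor):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    conv = nn.Conv1d(80, 256, kernel_size=5, padding=2)  # hypothetical sizes
    lstm = nn.LSTM(256, 128, batch_first=True, bidirectional=True)

    x = torch.randn(4, 80, 50)                      # (batch, channels, time)
    input_lengths = torch.tensor([50, 42, 30, 12])  # sorted in decreasing order
    x = F.dropout(F.relu(conv(x)), 0.5, True)
    x = x.transpose(1, 2)                           # (batch, time, channels)
    packed = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)
    out, _ = lstm(packed)
    out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)  # (4, 50, 256)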
Example #3
    def forward(self,
                task_index: torch.IntTensor,
                tokens: Dict[str, torch.LongTensor],
                epoch_trained: torch.IntTensor,
                valid_discriminator: Discriminator,
                reverse: torch.ByteTensor,
                for_training: torch.ByteTensor,
                text_id: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        embedded_text_input = self._text_field_embedder(tokens)
        tokens_mask = util.get_text_field_mask(tokens)
        output_dict = dict()
        embedded_text_input = self._input_dropout(embedded_text_input)

        shared_encoded_text = self._shared_encoder(embedded_text_input, tokens_mask)
        shared_encoded_text, s_weights = self._s_att(self._s_query, shared_encoded_text, shared_encoded_text)

        # shared_encoded_text = self._seq2vec(shared_encoded_text, tokens_mask)
        # shared_encoded_text = get_final_encoder_states(shared_encoded_text, tokens_mask, bidirectional=True)
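        # note: squeeze() with no dim argument also drops the batch axis when
        # batch_size == 1; squeeze(1) would be safer if the attention output is (batch, 1, dim)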
        output_dict["share_embedding"] = shared_encoded_text.squeeze()

        private_encoded_text = self._private_encoder(embedded_text_input, tokens_mask)
        private_encoded_text, p_weights = self._p_att(self._s_query, private_encoded_text, private_encoded_text)
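        # note: the private attention reuses self._s_query; if a separate private
        # query exists, passing self._s_query here may be unintentional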
        # private_encoded_text = self._seq2vec(private_encoded_text, tokens_mask)
        # private_encoded_text = get_final_encoder_states(private_encoded_text, tokens_mask, bidirectional=True)
        output_dict["private_embedding"] = private_encoded_text.squeeze()

        if not for_training:
            with open("attn.txt", "a") as f:
                f.write(f"Task: {TASKS_NAME[task_index.cpu().item()]}\nLine ID: ")
                f.write(" ".join(list(map(str, text_id.cpu().detach().numpy()))))
                f.write("\nShared Encoder Att: ")
                f.write(" ".join(list(map(str, s_weights.squeeze().cpu().detach().numpy()))))
                f.write("\nPrivate Encoder Att: ")
                f.write(" ".join(list(map(str, p_weights.squeeze().cpu().detach().numpy()))))
                f.write("\n\n\n")
        embedded_text = torch.cat([shared_encoded_text, private_encoded_text], -1).squeeze(1)
        output_dict["embedded_text"] = embedded_text
        return output_dict
Example #4
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                sentence_spans: torch.IntTensor = None,
                sent_labels: torch.IntTensor = None,
                evd_chain_labels: torch.IntTensor = None,
                q_type: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        if self._sent_labels_src == 'chain':
            batch_size, num_spans = sent_labels.size()
            sent_labels_mask = (sent_labels >= 0).float()
            # we use the chain as the label to supervise the gate
            # In this model, we only take the first chain in ``evd_chain_labels`` for supervision,
            # right now the number of chains should only be one too.
            evd_chain_labels = evd_chain_labels[:, 0].long()
            # build the gate labels. The dim is set to 1 + num_spans to account for the end embedding
            # shape: (batch_size, 1+num_spans)
            sent_labels = sent_labels.new_zeros((batch_size, 1+num_spans))
            sent_labels.scatter_(1, evd_chain_labels, 1.)
            # remove the column for end embedding
            # shape: (batch_size, num_spans)
            sent_labels = sent_labels[:, 1:].float()
            # make the padding be -1
            sent_labels = sent_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)

        # bert embedding for answer prediction
        # shape: [batch_size, max_q_len, emb_size]
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=0)
        # shape: [batch_size, num_sent, max_sent_len+q_len, embedding_dim]
        embedded_passage = self._text_field_embedder(passage, num_wrapping_dims=1)
        # print('\npassage size:', embedded_passage.shape)
        #embedded_question = self._bert_projection(embedded_question)
        #embedded_passage = self._bert_projection(embedded_passage)
        #print('size embedded_passage:', embedded_passage.shape)
        # mask
        ques_mask = util.get_text_field_mask(question, num_wrapping_dims=0).float()
        context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()

        # gate prediction
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        # Shape(gate_mask): (batch_size, num_spans)
        #gate_logit, gate, pred_sent_probs = self._span_gate(spans_rep_sp, spans_mask)
        gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(embedded_passage, context_mask,
                                                                                    self._gate_self_attention_layer,
                                                                                    self._gate_sent_encoder)
        batch_size, num_spans, max_batch_span_width = context_mask.size()

        loss = F.nll_loss(F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
                          sent_labels.long().view(batch_size * num_spans), ignore_index=-1)
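        # ignore_index=-1 skips the positions that were padded with -1 in sent_labels above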

        gate = (gate >= 0.3).long()
        gate = gate.view(batch_size, num_spans)

        output_dict = {
            "pred_sent_labels": gate, #[B, num_span]
            "gate_probs": pred_sent_probs[:, 1].view(batch_size, num_spans), #[B, num_span]
        }
        if self._output_att_scores:
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # Compute the loss for training.
        try:
            #loss = strong_sup_loss
            self._loss_trackers['loss'](loss)
            output_dict["loss"] = loss
        except RuntimeError:
            print('\n meta_data:', metadata)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            #token_spans_sp = []
            #token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_sent_tokens'])
                #token_spans_sp.append(metadata[i]['token_spans_sp'])
                #token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                #offsets = metadata[i]['token_offsets']
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                # shift sentence indices back
                evd_possible_chains.append([s_idx-1 for s_idx in metadata[i]['evd_possible_chains'][0] if s_idx > 0])
                ans_sent_idxs.append([s_idx-1 for s_idx in metadata[i]['ans_sent_idxs']])
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_gate = gate[i].detach().cpu().numpy()
                    if any([pred_sent_gate[s_idx-1] > 0 for s_idx in metadata[i]['ans_sent_idxs']]):
                        self.evd_ans_metric(1)
                    else:
                        self.evd_ans_metric(0)
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1), gate_mask.view(-1))
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_sent_tokens'] = passage_tokens
            #output_dict['token_spans_sp'] = token_spans_sp
            #output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['_id'] = ids

        return output_dict
Example #5
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                sentence_spans: torch.IntTensor = None,
                sent_labels: torch.IntTensor = None,
                evd_chain_labels: torch.IntTensor = None,
                q_type: torch.IntTensor = None,
                transition_mask: torch.IntTensor = None,
                start_transition_mask: torch.Tensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # In this model, we only take the first chain in ``evd_chain_labels`` for supervision
        evd_chain_labels = evd_chain_labels[:, 0] if evd_chain_labels is not None else None
        # there may be some instances that we can't find any evd chain for training
        # In that case, use the mask to ignore those instances
        evd_instance_mask = (evd_chain_labels[:, 0] != 0).float() if evd_chain_labels is not None else None
        #print('passage size:', passage['bert'].shape)
        # bert embedding for answer prediction
        # shape: [batch_size, max_q_len, emb_size]
        embedded_question = self._text_field_embedder(question)
        # shape: [batch_size, num_sent, max_sent_len+q_len, embedding_dim]
        embedded_passage = self._text_field_embedder(passage)
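        # note: the passage mask below is built with num_wrapping_dims=1; if the passage
        # field is nested per sentence, this embedder call may need num_wrapping_dims=1 too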
        # print('\npassage size:', embedded_passage.shape)
        #embedded_question = self._bert_projection(embedded_question)
        #embedded_passage = self._bert_projection(embedded_passage)
        #print('size embedded_passage:', embedded_passage.shape)
        # mask
        ques_mask = util.get_text_field_mask(question, num_wrapping_dims=0).float()
        context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()
        #print(context_mask.shape)
        # chain prediction
        # Shape(all_predictions): (batch_size, num_decoding_steps)
        # Shape(all_logprobs): (batch_size, num_decoding_steps)
        # Shape(seq_logprobs): (batch_size,)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(gate_probs): (batch_size * num_spans, 1)
        # Shape(gate_mask): (batch_size, num_spans)
        # Shape(g_att_score): (batch_size, num_heads, num_spans, num_spans)
        # Shape(orders): (batch_size, K, num_spans)
        all_predictions,    \
        all_logprobs,       \
        seq_logprobs,       \
        gate,               \
        gate_probs,         \
        gate_mask,          \
        g_att_score,        \
        orders = self._span_gate(embedded_passage, context_mask,
                                 embedded_question, ques_mask,
                                 evd_chain_labels,
                                 self._gate_self_attention_layer,
                                 self._gate_sent_encoder,
                                 transition_mask,
                                 start_transition_mask)
        batch_size, num_spans, max_batch_span_width = context_mask.size()

        output_dict = {
            "pred_sent_labels": gate.squeeze(1).view(batch_size, num_spans), #[B, num_span]
            "gate_probs": gate_probs.squeeze(1).view(batch_size, num_spans), #[B, num_span]
            "pred_sent_orders": orders, #[B, K, num_span]
        }
        if self._output_att_scores:
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # compute evd rl training metric, rewards, and loss
        evd_TP, evd_NP, evd_NT = self._f1_metrics(gate.squeeze(1).view(batch_size, num_spans),
                                                  sent_labels,
                                                  mask=gate_mask,
                                                  instance_mask=evd_instance_mask if self.training else None,
                                                  sum=False)
        # print("TP:", evd_TP)
        # print("NP:", evd_NP)
        # print("NT:", evd_NT)
        evd_ps = np.array(evd_TP) / (np.array(evd_NP) + 1e-13)
        evd_rs = np.array(evd_TP) / (np.array(evd_NT) + 1e-13)
        evd_f1s = 2. * ((evd_ps * evd_rs) / (evd_ps + evd_rs + 1e-13))
        predict_mask = get_evd_prediction_mask(all_predictions.unsqueeze(1), eos_idx=0)[0]
        gold_mask = get_evd_prediction_mask(evd_chain_labels, eos_idx=0)[0]
        # default to take multiple predicted chains, so unsqueeze dim 1
        self.evd_sup_acc_metric(predictions=all_predictions.unsqueeze(1), gold_labels=evd_chain_labels,
                                predict_mask=predict_mask, gold_mask=gold_mask, instance_mask=evd_instance_mask)
        predict_mask = predict_mask.float().squeeze(1)
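        # note: the sum below assumes evd_instance_mask is not None, i.e. chain labels
        # were provided for this batch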
        rl_loss = -torch.mean(torch.sum(all_logprobs * predict_mask * evd_instance_mask[:, None], dim=1))
        # torch.cuda.empty_cache()
        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        # Compute before loss for rl
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            #token_spans_sp = []
            #token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            pred_chains_include_ans = []
            beam_pred_chains_include_ans = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_sent_tokens'])
                #token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                #offsets = metadata[i]['token_offsets']
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                # shift sentence indices back
                evd_possible_chains.append([s_idx-1 for s_idx in metadata[i]['evd_possible_chains'][0] if s_idx > 0])
                ans_sent_idxs.append([s_idx-1 for s_idx in metadata[i]['ans_sent_idxs']])
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_orders = orders[i].detach().cpu().numpy()
                    if any([pred_sent_orders[0][s_idx-1] >= 0 for s_idx in metadata[i]['ans_sent_idxs']]):
                        self.evd_ans_metric(1)
                        pred_chains_include_ans.append(1)
                    else:
                        self.evd_ans_metric(0)
                        pred_chains_include_ans.append(0)
                    if any([any([pred_sent_orders[beam][s_idx-1] >= 0 for s_idx in metadata[i]['ans_sent_idxs']]) 
                                                                      for beam in range(len(pred_sent_orders))]):
                        self.evd_beam_ans_metric(1)
                        beam_pred_chains_include_ans.append(1)
                    else:
                        self.evd_beam_ans_metric(0)
                        beam_pred_chains_include_ans.append(0)

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_sent_tokens'] = passage_tokens
            #output_dict['token_spans_sp'] = token_spans_sp
            #output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['pred_chains_include_ans'] = pred_chains_include_ans
            output_dict['beam_pred_chains_include_ans'] = beam_pred_chains_include_ans
            output_dict['_id'] = ids

        # Compute the loss for training.
        if evd_chain_labels is not None:
            try:
                loss = rl_loss
                self._loss_trackers['loss'](loss)
                self._loss_trackers['rl_loss'](rl_loss)
                output_dict["loss"] = loss
            except RuntimeError:
                print('\n meta_data:', metadata)
                print(output_dict['_id'])

        return output_dict
Example #6
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            evd_chain_labels: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        if self._sent_labels_src == 'chain':
            batch_size, num_spans = sent_labels.size()
            sent_labels_mask = (sent_labels >= 0).float()
            # we use the chain as the label to supervise the gate
            # In this model, we only take the first chain in ``evd_chain_labels`` for supervision,
            # right now the number of chains should only be one too.
            evd_chain_labels = evd_chain_labels[:, 0].long()
            # build the gate labels. The dim is set to 1 + num_spans to account for the end embedding
            # shape: (batch_size, 1+num_spans)
            sent_labels = sent_labels.new_zeros((batch_size, 1 + num_spans))
            sent_labels.scatter_(1, evd_chain_labels, 1.)
            # remove the column for end embedding
            # shape: (batch_size, num_spans)
            sent_labels = sent_labels[:, 1:].float()
            # make the padding be -1
            sent_labels = sent_labels * sent_labels_mask + -1. * (
                1 - sent_labels_mask)

        # word + char embedding
        embedded_question = self._text_field_embedder(question)
        embedded_passage = self._text_field_embedder(passage)
        # mask
        ques_mask = util.get_text_field_mask(question).float()
        context_mask = util.get_text_field_mask(passage).float()

        # BiDAF for answer predicion
        ques_output = self._dropout(
            self._phrase_layer(embedded_question, ques_mask))
        context_output = self._dropout(
            self._phrase_layer(embedded_passage, context_mask))

        modeled_passage, _, qc_score = self.qc_att(context_output, ques_output,
                                                   ques_mask)

        modeled_passage = self._modeling_layer(modeled_passage, context_mask)

        # BiDAF for gate prediction
        ques_output_sp = self._dropout(
            self._phrase_layer_sp(embedded_question, ques_mask))
        context_output_sp = self._dropout(
            self._phrase_layer_sp(embedded_passage, context_mask))

        modeled_passage_sp, _, qc_score_sp = self.qc_att_sp(
            context_output_sp, ques_output_sp, ques_mask)

        modeled_passage_sp = self._modeling_layer_sp(modeled_passage_sp,
                                                     context_mask)

        # gate prediction
        # Shape(spans_rep): (batch_size * num_spans, max_batch_span_width, embedding_dim)
        # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
        spans_rep_sp, spans_mask = convert_sequence_to_spans(
            modeled_passage_sp, sentence_spans)
        spans_rep, _ = convert_sequence_to_spans(modeled_passage,
                                                 sentence_spans)
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        # Shape(gate_mask): (batch_size, num_spans)
        #gate_logit, gate, pred_sent_probs = self._span_gate(spans_rep_sp, spans_mask)
        gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(
            spans_rep_sp, spans_mask, self._gate_self_attention_layer,
            self._gate_sent_encoder)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()

        strong_sup_loss = F.nll_loss(
            F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
            sent_labels.long().view(batch_size * num_spans),
            ignore_index=-1)

        gate = (gate >= 0.3).long()
        spans_rep = spans_rep * gate.unsqueeze(-1).float()
        attended_sent_embeddings = convert_span_to_sequence(
            modeled_passage_sp, spans_rep, spans_mask)

        modeled_passage = attended_sent_embeddings + modeled_passage

        self_att_passage = self._self_attention_layer(modeled_passage,
                                                      mask=context_mask)
        modeled_passage = modeled_passage + self_att_passage[0]
        self_att_score = self_att_passage[2]

        output_start = self._span_start_encoder(modeled_passage, context_mask)
        span_start_logits = self.linear_start(output_start).squeeze(
            2) - 1e30 * (1 - context_mask)
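        # subtracting 1e30 at padded positions drives their logits toward -inf,
        # so they vanish after softmax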
        output_end = torch.cat([modeled_passage, output_start], dim=2)
        output_end = self._span_end_encoder(output_end, context_mask)
        span_end_logits = self.linear_end(output_end).squeeze(
            2) - 1e30 * (1 - context_mask)

        output_type = torch.cat([modeled_passage, output_end, output_start],
                                dim=2)
        output_type = torch.max(output_type, 1)[0]
        # output_type = torch.max(self.rnn_type(output_type, context_mask), 1)[0]
        predict_type = self.linear_type(output_type)
        type_predicts = torch.argmax(predict_type, 1)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span": best_span,
            "pred_sent_labels": gate.view(batch_size,
                                          num_spans),  #[B, num_span]
            "gate_probs":
            pred_sent_probs[:, 1].view(batch_size, num_spans),  #[B, num_span]
        }
        if self._output_att_scores:
            if qc_score is not None:
                output_dict['qc_score'] = qc_score
            if qc_score_sp is not None:
                output_dict['qc_score_sp'] = qc_score_sp
            if self_att_score is not None:
                output_dict['self_attention_score'] = self_att_score
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # Compute the loss for training.
        if span_start is not None:
            try:
                start_loss = nll_loss(
                    util.masked_log_softmax(span_start_logits, None),
                    span_start.squeeze(-1))
                end_loss = nll_loss(
                    util.masked_log_softmax(span_end_logits, None),
                    span_end.squeeze(-1))
                type_loss = nll_loss(
                    util.masked_log_softmax(predict_type, None), q_type)
                loss = start_loss + end_loss + type_loss + strong_sup_loss
                self._loss_trackers['loss'](loss)
                self._loss_trackers['start_loss'](start_loss)
                self._loss_trackers['end_loss'](end_loss)
                self._loss_trackers['type_loss'](type_loss)
                self._loss_trackers['strong_sup_loss'](strong_sup_loss)
                output_dict["loss"] = loss
            except RuntimeError:
                print('\n meta_data:', metadata)
                print(span_start_logits.shape)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            token_spans_sp = []
            token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            ids = []
            count_yes = 0
            count_no = 0
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                token_spans_sp.append(metadata[i]['token_spans_sp'])
                token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                if type_predicts[i] == 1:
                    best_span_string = 'yes'
                    count_yes += 1
                elif type_predicts[i] == 2:
                    best_span_string = 'no'
                    count_no += 1
                else:
                    predicted_span = tuple(best_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    best_span_string = passage_str[start_offset:end_offset]

                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                if answer_texts:
                    self._squad_metrics(best_span_string.lower(), answer_texts)

                # shift sentence indices back
                evd_possible_chains.append([
                    s_idx - 1
                    for s_idx in metadata[i]['evd_possible_chains'][0]
                    if s_idx > 0
                ])
                ans_sent_idxs.append(
                    [s_idx - 1 for s_idx in metadata[i]['ans_sent_idxs']])
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1),
                             gate_mask.view(-1))
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['token_spans_sp'] = token_spans_sp
            output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['_id'] = ids

        return output_dict
Example #7
    def forward(self,  # type: ignore
                task_index: torch.IntTensor,
                reverse: torch.ByteTensor,
                for_training: torch.ByteTensor,
                train_stage: torch.IntTensor,
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        :param task_index: index of the current task/domain
        :param reverse: flag for the adversarial discriminators (gradient reversal)
        :param for_training: whether the model is run in training mode
        :param train_stage: one of ["share_senti", "share_classify",
        "share_classify_adversarial", "domain_valid", "domain_valid_adversarial"]
        :param tokens: token tensors for the input text
        :param label: sentiment label
        :return: output dict containing the loss
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()
        embed_tokens = self._encoder(embedded_text, mask)
        batch_size = get_batch_size(embed_tokens)
        # bs * (25*4)
        seq_vec = self._seq_vec(embed_tokens, mask)
        # TODO add linear layer

        domain_embeddings = self._domain_embeddings(torch.arange(self._de_dim).cuda())

        de_scores = F.softmax(
            self._de_attention(seq_vec, domain_embeddings.expand(batch_size, *domain_embeddings.size())), dim=1)
        de_valid = False
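        # with probability 0.3, perturb the attention scores with Gaussian noise; the
        # perturbed ("invalid") domain embeddings act as negative examples for the
        # valid discriminator below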
        if np.random.rand() < 0.3:
            de_valid = True
            noise = 0.01 * torch.normal(mean=0.5,
                                        # std=torch.std(domain_embeddings).sign_())
                                        std=torch.empty(*de_scores.size()).fill_(1.0))
            de_scores = de_scores + noise.cuda()
        domain_embedding = torch.matmul(de_scores, domain_embeddings)
        domain_embedding = self._de_feedforward(domain_embedding)
        # train sentiment classify
        if train_stage.item() == 0 or not for_training:

            de_representation = torch.tanh(torch.add(domain_embedding, seq_vec))

            sentiment_logits = self._sentiment_discriminator(de_representation)
            if label is not None:
                loss = self._loss(sentiment_logits, label)
                self.metrics["{}_stm_acc".format(TASKS_NAME[task_index.item()])](sentiment_logits, label)

        if train_stage.item() == 1 or not for_training:
            s_domain_logits = self._s_domain_discriminator(seq_vec, reverse=reverse)
            task_index = task_index.expand(batch_size)
            loss = self._domain_loss(s_domain_logits, task_index)
            self.metrics["s_domain_acc"](s_domain_logits, task_index)

        if train_stage.item() == 2 or not for_training:
            valid_logits = self._valid_discriminator(domain_embedding, reverse=reverse)
            valid_label = torch.ones(batch_size).cuda()
            if de_valid:
                valid_label = torch.zeros(batch_size).cuda()
            if self._label_smoothing is not None and self._label_smoothing > 0.0:
                loss = sequence_cross_entropy_with_logits(valid_logits,
                                                          valid_label.unsqueeze(0).cuda(),
                                                          torch.tensor(1).unsqueeze(0).cuda(),
                                                          average="token",
                                                          label_smoothing=self._label_smoothing)
            else:
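                # note: as written, scatter_ on a size-2 tensor with a batch-sized float
                # index will fail; a per-example one-hot target such as
                # F.one_hot(valid_label.long(), num_classes=2).float() is probably intended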
                loss = self._valid_loss(valid_logits,
                                        torch.zeros(2).scatter_(0, valid_label, torch.tensor(1.0)).cuda())
            self.metrics["valid_acc"](valid_logits, valid_label)
        # TODO add orthogonal loss
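        # note: "loss" would be unbound here if none of the stage branches ran;
        # the stages above are assumed to be exhaustive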
        output_dict = {"loss": loss}

        return output_dict
Example #8
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                sentence_spans: torch.IntTensor = None,
                sent_labels: torch.IntTensor = None,
                evd_chain_labels: torch.IntTensor = None,
                q_type: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # In this model, we only take the first chain in ``evd_chain_labels`` for supervision
        evd_chain_labels = evd_chain_labels[:, 0] if evd_chain_labels is not None else None
        # there may be some instances that we can't find any evd chain for training
        # In that case, use the mask to ignore those instances
        evd_instance_mask = (evd_chain_labels[:, 0] != 0).float() if evd_chain_labels is not None else None

        # word + char embedding
        embedded_question = self._text_field_embedder(question)
        embedded_passage = self._text_field_embedder(passage)
        # mask
        ques_mask = util.get_text_field_mask(question).float()
        context_mask = util.get_text_field_mask(passage).float()

        # BiDAF for answer prediction (no dropout since the answer f1 is used as the reward)
        ques_output = self._phrase_layer(embedded_question, ques_mask)
        context_output = self._phrase_layer(embedded_passage, context_mask)

        modeled_passage, _, qc_score = self.qc_att(context_output, ques_output, ques_mask)

        modeled_passage = self._modeling_layer(modeled_passage, context_mask)

        # BiDAF for chain prediction
        ques_output_sp = self._dropout(self._phrase_layer_sp(embedded_question, ques_mask))
        context_output_sp = self._dropout(self._phrase_layer_sp(embedded_passage, context_mask))

        modeled_passage_sp, _, qc_score_sp = self.qc_att_sp(context_output_sp, ques_output_sp, ques_mask)

        modeled_passage_sp = self._modeling_layer_sp(modeled_passage_sp, context_mask)

        # chain prediction
        # Shape(spans_rep): (batch_size * num_spans, max_batch_span_width, embedding_dim)
        # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
        spans_rep_sp, spans_mask = convert_sequence_to_spans(modeled_passage_sp, sentence_spans)
        spans_rep, _ = convert_sequence_to_spans(modeled_passage, sentence_spans)
        # Shape(all_predictions): (batch_size, K, num_decoding_steps)
        # Shape(all_logprobs): (batch_size, K, num_decoding_steps)
        # Shape(seq_logprobs): (batch_size, K)
        # Shape(gate): (batch_size * K * num_spans, 1)
        # Shape(gate_probs): (batch_size * K * num_spans, 1)
        # Shape(gate_mask): (batch_size, num_spans)
        # Shape(g_att_score): (batch_size, num_heads, num_spans, num_spans)
        # Shape(orders): (batch_size, K, num_spans)
        all_predictions,    \
        all_logprobs,       \
        seq_logprobs,       \
        gate,               \
        gate_probs,         \
        gate_mask,          \
        g_att_score,        \
        orders = self._span_gate(spans_rep_sp, spans_mask,
                                 ques_output_sp, ques_mask,
                                 evd_chain_labels,
                                 self._gate_self_attention_layer,
                                 self._gate_sent_encoder,
                                 get_all_beam=True)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()
        beam_size = all_predictions.size(1)

        # expand all the tensor to fit the beam size
        num_toks = modeled_passage.size(1)
        emb_dim = spans_rep.size(-1)
        spans_rep = spans_rep.reshape(batch_size, num_spans, max_batch_span_width, emb_dim)
        spans_rep = spans_rep.unsqueeze(1).expand(batch_size, beam_size, num_spans, max_batch_span_width, emb_dim)
        spans_rep = spans_rep.reshape(batch_size * beam_size * num_spans, max_batch_span_width, emb_dim)
        spans_mask = spans_mask[:, None, :, :].expand(batch_size, beam_size, num_spans, max_batch_span_width)
        spans_mask = spans_mask.reshape(batch_size * beam_size, num_spans, max_batch_span_width)
        context_mask = context_mask.unsqueeze(1).expand(batch_size, beam_size, num_toks)
        context_mask = context_mask.reshape(batch_size * beam_size, num_toks)
        se_mask = gate.expand(batch_size * beam_size * num_spans, max_batch_span_width).unsqueeze(-1)
        se_mask = convert_span_to_sequence(modeled_passage_sp, se_mask, spans_mask).squeeze(-1)
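        # se_mask carries the predicted sentence gates back to token level; it zeroes
        # unselected sentences out of the span logits below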

        spans_rep = spans_rep * gate.unsqueeze(-1)
        attended_sent_embeddings = convert_span_to_sequence(modeled_passage_sp, spans_rep, spans_mask)

        modeled_passage = attended_sent_embeddings

        self_att_passage = self._self_attention_layer(modeled_passage, mask=context_mask)
        modeled_passage = modeled_passage + self_att_passage[0]
        self_att_score = self_att_passage[2]

        output_start = self._span_start_encoder(modeled_passage, context_mask)
        span_start_logits = self.linear_start(output_start).squeeze(2) - 1e30 * (1 - context_mask * se_mask)
        output_end = torch.cat([modeled_passage, output_start], dim=2)
        output_end = self._span_end_encoder(output_end, context_mask)
        span_end_logits = self.linear_end(output_end).squeeze(2) - 1e30 * (1 - context_mask * se_mask)

        output_type = torch.cat([modeled_passage, output_end, output_start], dim=2)
        output_type = torch.max(output_type, 1)[0]
        # output_type = torch.max(self.rnn_type(output_type, context_mask), 1)[0]
        predict_type = self.linear_type(output_type)
        type_predicts = torch.argmax(predict_type, 1)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits.view(batch_size, beam_size, num_toks)[:, 0, :],
            "span_end_logits": span_end_logits.view(batch_size, beam_size, num_toks)[:, 0, :],
            "best_span": best_span.view(batch_size, beam_size, 2)[:, 0, :],
            "pred_sent_labels": gate.squeeze(1).view(batch_size, beam_size, num_spans)[:, 0, :], #[B, num_span]
            "gate_probs": gate_probs.squeeze(1).view(batch_size, beam_size, num_spans)[:, 0, :], #[B, num_span]
            "pred_sent_orders": orders, #[B, K, num_span]
        }
        if self._output_att_scores:
            if qc_score is not None:
                output_dict['qc_score'] = qc_score
            if qc_score_sp is not None:
                output_dict['qc_score_sp'] = qc_score_sp
            if self_att_score is not None:
                output_dict['self_attention_score'] = self_att_score
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # compute evd rl training metric, rewards, and loss
        evd_TP, evd_NP, evd_NT = self._f1_metrics(gate.squeeze(1).view(batch_size, beam_size, num_spans)[:, 0, :],
                                                  sent_labels,
                                                  mask=gate_mask,
                                                  instance_mask=evd_instance_mask if self.training else None,
                                                  sum=False)
        evd_ps = np.array(evd_TP) / (np.array(evd_NP) + 1e-13)
        evd_rs = np.array(evd_TP) / (np.array(evd_NT) + 1e-13)
        evd_f1s = 2. * ((evd_ps * evd_rs) / (evd_ps + evd_rs + 1e-13))
        #print("evd_f1s:", evd_f1s)
        predict_mask = get_evd_prediction_mask(all_predictions[:, :1, :], eos_idx=0)[0]
        gold_mask = get_evd_prediction_mask(evd_chain_labels, eos_idx=0)[0]
        # ChainAccuracy defaults to take multiple predicted chains, so unsqueeze dim 1
        self.evd_sup_acc_metric(predictions=all_predictions[:, :1, :], gold_labels=evd_chain_labels,
                                predict_mask=predict_mask, gold_mask=gold_mask, instance_mask=evd_instance_mask)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        # Compute before loss for rl
        best_span = best_span.view(batch_size, beam_size, 2)
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            token_spans_sp = []
            token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            pred_chains_include_ans = []
            beam_pred_chains_include_ans = []
            beam2_pred_chains_include_ans = []
            ids = []
            ems = []
            f1s = []
            rb_ems = []
            ch_lens = []
            #count_yes = 0
            #count_no = 0
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                token_spans_sp.append(metadata[i]['token_spans_sp'])
                token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)
                beam_best_span_string = []
                beam_f1s = []
                beam_ems = []
                beam_rb_ems = []
                beam_ch_lens = []
                for b_idx in range(beam_size):
                    if type_predicts[i] == 1:
                        best_span_string = 'yes'
                        #count_yes += 1
                    elif type_predicts[i] == 2:
                        best_span_string = 'no'
                        #count_no += 1
                    else:
                        predicted_span = tuple(best_span[i, b_idx].detach().cpu().numpy())
                        start_offset = offsets[predicted_span[0]][0]
                        end_offset = offsets[predicted_span[1]][1]
                        best_span_string = passage_str[start_offset:end_offset]
                    beam_best_span_string.append(best_span_string)

                    if answer_texts:
                        em, f1 = self._squad_metrics(best_span_string.lower(), answer_texts)
                        beam_ems.append(em)
                        beam_f1s.append(f1)

                    rb_chain = [s_idx-1 for s_idx in metadata[i]['evd_possible_chains'][0] if s_idx > 0]
                    pd_chain = [s_idx-1 for s_idx in all_predictions[i, b_idx].detach().cpu().numpy() if s_idx > 0]
                    beam_rb_ems.append(float(rb_chain == pd_chain))
                    beam_ch_lens.append(float(len(pd_chain)))
                ems.append(beam_ems)
                f1s.append(beam_f1s)
                rb_ems.append(beam_rb_ems)
                ch_lens.append(beam_ch_lens)
                output_dict['best_span_str'].append(beam_best_span_string[0])

                # shift sentence indices back
                evd_possible_chains.append([s_idx-1 for s_idx in metadata[i]['evd_possible_chains'][0] if s_idx > 0])
                ans_sent_idxs.append([s_idx-1 for s_idx in metadata[i]['ans_sent_idxs']])
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_orders = orders[i].detach().cpu().numpy()
                    if any([pred_sent_orders[0][s_idx-1] >= 0 for s_idx in metadata[i]['ans_sent_idxs']]):
                        self.evd_ans_metric(1)
                        pred_chains_include_ans.append(1)
                    else:
                        self.evd_ans_metric(0)
                        pred_chains_include_ans.append(0)
                    if any([any([pred_sent_orders[beam][s_idx-1] >= 0 for s_idx in metadata[i]['ans_sent_idxs']]) 
                                                                        for beam in range(len(pred_sent_orders))]):
                        self.evd_beam_ans_metric(1)
                        beam_pred_chains_include_ans.append(1)
                    else:
                        self.evd_beam_ans_metric(0)
                        beam_pred_chains_include_ans.append(0)
                    if any([any([pred_sent_orders[beam][s_idx-1] >= 0 for s_idx in metadata[i]['ans_sent_idxs']]) 
                                                                        for beam in range(2)]):
                        self.evd_beam2_ans_metric(1)
                        beam2_pred_chains_include_ans.append(1)
                    else:
                        self.evd_beam2_ans_metric(0)
                        beam2_pred_chains_include_ans.append(0)

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['token_spans_sp'] = token_spans_sp
            output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['pred_chains_include_ans'] = pred_chains_include_ans
            output_dict['beam_pred_chains_include_ans'] = beam_pred_chains_include_ans
            output_dict['_id'] = ids

        # Compute the loss for training.
        # RL Loss equals ``-log(P) * (R - baseline)``
        # Shape: (batch_size, num_decoding_steps)
        tot_rs = 0.
        if "ans" in self._ft_reward:
            ans_rs = seq_logprobs.new_tensor(f1s) # shape: (batch_size, beam_size)
            tot_rs = tot_rs + ans_rs
        if "rb" in self._ft_reward:
            rb_rs = seq_logprobs.new_tensor(rb_ems) # shape: (batch_size, beam_size)
            tot_rs = tot_rs + 0.7 * rb_rs
        if "len" in self._ft_reward:
            len_rs = seq_logprobs.new_tensor(ch_lens) # shape: (batch_size, beam_size)
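            # reward decreases linearly with chain length (1 - len/5); empty chains get zero reward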
            len_rs = (1. - len_rs / 5.) * (len_rs > 0).float()
            tot_rs = tot_rs + 0.7 * len_rs
        #rs_baseline = torch.mean(rs)
        rs_baseline = 0  # torch.mean(tot_rs)
        tot_rs = tot_rs - rs_baseline
        rl_loss = -torch.mean(seq_logprobs * tot_rs)
        if span_start is not None:
            try:
                loss = rl_loss
                self._loss_trackers['loss'](loss)
                self._loss_trackers['rl_loss'](rl_loss)
                output_dict["loss"] = loss
            except RuntimeError:
                print('\n meta_data:', metadata)
                print(output_dict['_id'])
                print(span_start_logits.shape)

        return output_dict
Example #9
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            evd_chain_labels: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # In this model, we only take the first chain in ``evd_chain_labels`` for supervision
        evd_chain_labels = (evd_chain_labels[:, 0]
                            if evd_chain_labels is not None else None)
        # there may be some instances that we can't find any evd chain for training
        # In that case, use the mask to ignore those instances
        evd_instance_mask = ((evd_chain_labels[:, 0] != 0).float()
                             if evd_chain_labels is not None else None)

        # bert embedding for answer prediction
        # shape: [batch_size, max_q_len, emb_size]
        embedded_question = self._text_field_embedder(question)
        # shape: [batch_size, num_para, max_para_len, embedding_dim]
        embedded_passage = self._text_field_embedder(passage)

        # mask
        ques_mask = util.get_text_field_mask(question,
                                             num_wrapping_dims=0).float()
        context_mask = util.get_text_field_mask(passage,
                                                num_wrapping_dims=1).float()

        # extract word embeddings for each sentence
        batch_size, num_para, max_para_len, emb_size = embedded_passage.size()
        batch_size, num_para, max_num_sent, _ = sentence_spans.size()
        # Shape(spans_rep): (batch_size*num_para*max_num_sent, max_batch_span_width, embedding_dim)
        # Shape(spans_mask): (batch_size*num_para, max_num_sent, max_batch_span_width)
        spans_rep_sp, spans_mask = convert_sequence_to_spans(
            embedded_passage.view(batch_size * num_para, max_para_len,
                                  emb_size),
            sentence_spans.view(batch_size * num_para, max_num_sent, 2))
        _, _, max_batch_span_width = spans_mask.size()
        # flatten out the num_para dimension
        # shape: (batch_size, num_para, max_num_sent); marks which sentences
        # are not pad (i.e. the sentence contains at least one non-pad token)
        sentence_mask = (spans_mask.sum(-1) > 0).float().view(
            batch_size, num_para, max_num_sent)
        # the maximum total number of sentences for each example
        max_num_global_sent = torch.max(sentence_mask.sum([1,
                                                           2])).long().item()
        num_spans = max_num_global_sent
        # shape: (batch_size, num_spans, max_batch_span_width*embedding_dim),
        # where num_spans equals num_para * num_sent (no max, because the
        # paragraph paddings are removed) and also equals max_num_global_sent
        spans_rep_sp = convert_span_to_sequence(
            spans_rep_sp.new_zeros((batch_size, max_num_global_sent)),
            spans_rep_sp.view(batch_size * num_para, max_num_sent,
                              max_batch_span_width * emb_size), sentence_mask)
        # shape: (batch_size * num_spans, max_batch_span_width, embedding_dim),
        spans_rep_sp = spans_rep_sp.view(batch_size * max_num_global_sent,
                                         max_batch_span_width, emb_size)

        # shape: (batch_size, num_spans, max_batch_span_width),
        spans_mask = convert_span_to_sequence(
            spans_mask.new_zeros((batch_size, max_num_global_sent)),
            spans_mask, sentence_mask)
        # chain prediction
        # Shape(all_predictions): (batch_size, num_decoding_steps)
        # Shape(all_logprobs): (batch_size, num_decoding_steps)
        # Shape(seq_logprobs): (batch_size,)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(gate_probs): (batch_size * num_spans, 1)
        # Shape(gate_mask): (batch_size, num_spans)
        # Shape(g_att_score): (batch_size, num_heads, num_spans, num_spans)
        # Shape(orders): (batch_size, K, num_spans)
        all_predictions, \
        all_logprobs, \
        seq_logprobs, \
        gate, \
        gate_probs, \
        gate_mask, \
        g_att_score, \
        orders = self._span_gate(spans_rep_sp, spans_mask,
                                 embedded_question, ques_mask,
                                 evd_chain_labels,
                                 self._gate_self_attention_layer,
                                 self._gate_sent_encoder)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()

        output_dict = {
            "pred_sent_labels":
            gate.squeeze(1).view(batch_size, num_spans),  # [B, num_span]
            "gate_probs":
            gate_probs.squeeze(1).view(batch_size, num_spans),  # [B, num_span]
            "pred_sent_orders": orders,  # [B, K, num_span]
        }
        if self._output_att_scores:
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        # compute evd rl training metric, rewards, and loss
        evd_TP, evd_NP, evd_NT = self._f1_metrics(
            gate.squeeze(1).view(batch_size, num_spans),
            sent_labels,
            mask=gate_mask,
            instance_mask=evd_instance_mask if self.training else None,
            sum=False)
        evd_ps = np.array(evd_TP) / (np.array(evd_NP) + 1e-13)
        evd_rs = np.array(evd_TP) / (np.array(evd_NT) + 1e-13)
        evd_f1s = 2. * ((evd_ps * evd_rs) / (evd_ps + evd_rs + 1e-13))
        predict_mask = get_evd_prediction_mask(all_predictions.unsqueeze(1),
                                               eos_idx=0)[0]
        gold_mask = get_evd_prediction_mask(evd_chain_labels, eos_idx=0)[0]
        # default to take multiple predicted chains, so unsqueeze dim 1
        self.evd_sup_acc_metric(predictions=all_predictions.unsqueeze(1),
                                gold_labels=evd_chain_labels,
                                predict_mask=predict_mask,
                                gold_mask=gold_mask,
                                instance_mask=evd_instance_mask)
        predict_mask = predict_mask.float().squeeze(1)
        rl_loss = -torch.mean(
            torch.sum(all_logprobs * predict_mask * evd_instance_mask[:, None],
                      dim=1))
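        # note: this assumes evd_instance_mask is not None, i.e. chain labels were
        # provided for the batch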

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        # Compute before loss for rl
        if metadata is not None:
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            # token_spans_sp = []
            token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            pred_chains_include_ans = []
            beam_pred_chains_include_ans = []
            ids = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                # token_spans_sp.append(metadata[i]['token_spans_sp'])
                token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                # shift sentence indices back
                evd_possible_chains.append([
                    s_idx - 1
                    for s_idx in metadata[i]['evd_possible_chains'][0]
                    if s_idx > 0
                ])
                ans_sent_idxs.append(
                    [s_idx - 1 for s_idx in metadata[i]['ans_sent_idxs']])
                if len(metadata[i]['ans_sent_idxs']) > 0:
                    pred_sent_orders = orders[i].detach().cpu().numpy()
                    if any([
                            pred_sent_orders[0][s_idx - 1] >= 0
                            for s_idx in metadata[i]['ans_sent_idxs']
                    ]):
                        self.evd_ans_metric(1)
                        pred_chains_include_ans.append(1)
                    else:
                        self.evd_ans_metric(0)
                        pred_chains_include_ans.append(0)
                    if any([
                            any([
                                pred_sent_orders[beam][s_idx - 1] >= 0
                                for s_idx in metadata[i]['ans_sent_idxs']
                            ]) for beam in range(len(pred_sent_orders))
                    ]):
                        self.evd_beam_ans_metric(1)
                        beam_pred_chains_include_ans.append(1)
                    else:
                        self.evd_beam_ans_metric(0)
                        beam_pred_chains_include_ans.append(0)

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            # output_dict['token_spans_sp'] = token_spans_sp
            output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['pred_chains_include_ans'] = pred_chains_include_ans
            output_dict[
                'beam_pred_chains_include_ans'] = beam_pred_chains_include_ans
            output_dict['_id'] = ids

        # Compute the loss for training.
        if evd_chain_labels is not None:
            loss = rl_loss
            self._loss_trackers['loss'](loss)
            self._loss_trackers['rl_loss'](rl_loss)
            output_dict["loss"] = loss

        return output_dict