Example #1
import torch
import torch.nn as nn

# Project-local modules (paths assumed): `layers` provides the attention
# primitives and `custom` the AttentionRNN encoder; MTLSTM is the CoVe
# encoder from the `cove` package (McCann et al., 2017). `normalize_emb_`
# is assumed to be an in-place embedding-normalization helper.
import layers
import custom
from cove import MTLSTM


class RnnDocReader(nn.Module):
    """Network for the Document Reader module of DrQA."""
    RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}

    def __init__(self, opt, padding_idx=0, embedding=None,
                 normalize_emb=False, embedding_order=True):
        super(RnnDocReader, self).__init__()
        # Store config
        self.opt = opt
        # NOTE: the trainable word-embedding lookup is disabled in this
        # variant; the model relies entirely on CoVe embeddings (Example #2
        # keeps the full pretrained/partially-tuned embedding setup).
        if opt['pos']:
            self.pos_embedding = nn.Embedding(opt['pos_size'], opt['pos_dim'])
            if normalize_emb: normalize_emb_(self.pos_embedding.weight.data)
        if opt['ner']:
            self.ner_embedding = nn.Embedding(opt['ner_size'], opt['ner_dim'])
            if normalize_emb: normalize_emb_(self.ner_embedding.weight.data)
        # Projection for attention weighted question
        if opt['use_qemb']:
            # 3 * embedding_dim: the representations are residual CoVe
            # outputs ([GloVe; CoVe] = 300 + 600 dims)
            self.qemb_match = layers.SeqAttnMatch(3 * opt['embedding_dim'])
        if opt['use_cove']:
            # residual_embeddings=True makes MTLSTM return [GloVe; CoVe],
            # i.e. 3 * embedding_dim per token
            self.cove_embedding = MTLSTM(n_vocab=embedding.size(0),
                                         vectors=embedding.clone(),
                                         residual_embeddings=True)
            if not opt['fine_tune']:
                for p in self.cove_embedding.parameters():
                    p.requires_grad = False

        # Input size to RNN: word emb + question emb + manual features
        doc_input_size = opt['embedding_dim'] + opt['num_features']
        question_input_size = opt['embedding_dim']
        if opt['use_qemb']:
            # The attended question vectors match x2_emb's width
            # (3 * embedding_dim, the residual CoVe output), assuming
            # layers.SeqAttnMatch returns a weighted sum of its second
            # argument as in stock DrQA.
            doc_input_size += 3 * opt['embedding_dim']
        if opt['pos']:
            doc_input_size += opt['pos_dim']
        if opt['ner']:
            doc_input_size += opt['ner_dim']
        if opt['use_cove']:
            # CoVe contributes 2 * embedding_dim per token
            doc_input_size += 2 * opt['embedding_dim']
            question_input_size += 2 * opt['embedding_dim']

        print('doc_input_size:', doc_input_size)
        self.attention_rnns = custom.AttentionRNN(
            opt,
            doc_input_size=doc_input_size,
            question_input_size=question_input_size,
            ratio=opt['reduction_ratio'])

        # Output sizes of rnn encoders
        doc_hidden_size = (2 * opt['hidden_size']
                           + opt['hidden_size'] // opt['reduction_ratio'])
        question_hidden_size = (2 * opt['hidden_size']
                                + opt['hidden_size'] // opt['reduction_ratio'])

        # Question merging
        if opt['question_merge'] not in ['avg', 'self_attn']:
            raise NotImplementedError('question_merge = %s' % opt['question_merge'])
        if opt['question_merge'] == 'self_attn':
            self.self_attn = layers.LinearSeqAttn(question_hidden_size)


        # Bilinear attention for span start/end
        self.start_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
        )
        self.end_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
        )

    def forward(self, x1, x1_f, x1_pos, x1_ner, x1_mask, x2, x2_mask,
                x1_order, x2_order):
        """Inputs:
        x1 = document word indices             [batch * len_d]
        x1_f = document word features indices  [batch * len_d * nfeat]
        x1_pos = document POS tags             [batch * len_d]
        x1_ner = document entity tags          [batch * len_d]
        x1_mask = document padding mask        [batch * len_d]
        x2 = question word indices             [batch * len_q]
        x2_mask = question padding mask        [batch * len_q]
        """

        # Embed both document and question with CoVe; this variant requires
        # opt['use_cove']. The encoder's second argument is a length tensor
        # set to the full (padded) sequence length of every example.
        if self.opt['use_cove']:
            x1_emb_cove = self.cove_embedding(
                x1,
                torch.LongTensor(x1.size(0)).fill_(x1.size(1)).cuda())
            x2_emb_cove = self.cove_embedding(
                x2,
                torch.LongTensor(x2.size(0)).fill_(x2.size(1)).cuda())

        # The residual CoVe outputs serve directly as the word representations
        x1_emb = x1_emb_cove
        x2_emb = x2_emb_cove


        drnn_input_list = [x1_emb, x1_f]
        # Add attention-weighted question representation
        if self.opt['use_qemb']:
            x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
            drnn_input_list.append(x2_weighted_emb)
        if self.opt['pos']:
            x1_pos_emb = self.pos_embedding(x1_pos)
            if self.opt['dropout_emb'] > 0:
                x1_pos_emb = nn.functional.dropout(x1_pos_emb, p=self.opt['dropout_emb'],
                                               training=self.training)
            drnn_input_list.append(x1_pos_emb)
        if self.opt['ner']:
            x1_ner_emb = self.ner_embedding(x1_ner)
            if self.opt['dropout_emb'] > 0:
                x1_ner_emb = nn.functional.dropout(x1_ner_emb, p=self.opt['dropout_emb'],
                                               training=self.training)
            drnn_input_list.append(x1_ner_emb)
        drnn_input = torch.cat(drnn_input_list, 2)
        # Encode document with RNN
        doc_hiddens, question_hiddens = self.attention_rnns(
            drnn_input, x1_mask, x2_emb, x2_mask)
        if self.opt['question_merge'] == 'avg':
            q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask)
        elif self.opt['question_merge'] == 'self_attn':
            q_merge_weights = self.self_attn(question_hiddens, x2_mask)
        question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights)

        start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
        end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
        return start_scores, end_scores
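
A minimal smoke test for this variant might look like the following. It is a sketch under assumptions: the option keys are read off the constructor above, custom.AttentionRNN may require further keys not shown here, mask dtypes follow stock DrQA, and a CUDA device is needed because forward() calls .cuda().

# Hypothetical usage sketch (not part of the original source).
opt = {
    'pos': False, 'ner': False, 'use_qemb': True, 'use_cove': True,
    'fine_tune': False, 'embedding_dim': 300, 'num_features': 4,
    'hidden_size': 128, 'reduction_ratio': 4,
    'question_merge': 'self_attn', 'dropout_emb': 0,
}
embedding = torch.randn(50000, 300)  # stands in for pretrained GloVe
model = RnnDocReader(opt, embedding=embedding).cuda()

batch, len_d, len_q = 2, 40, 10
x1 = torch.randint(1, 50000, (batch, len_d)).cuda()
x1_f = torch.randn(batch, len_d, opt['num_features']).cuda()
x1_mask = torch.zeros(batch, len_d, dtype=torch.uint8).cuda()
x2 = torch.randint(1, 50000, (batch, len_q)).cuda()
x2_mask = torch.zeros(batch, len_q, dtype=torch.uint8).cuda()

start_scores, end_scores = model(x1, x1_f, None, None, x1_mask,
                                 x2, x2_mask, None, None)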
Example #2
import torch
import torch.nn as nn

# Project-local modules (paths assumed); `normalize_emb_` is likewise assumed
# to be an in-place embedding-normalization helper defined alongside them.
import layers
import custom
from cove import MTLSTM


class RnnDocReader(nn.Module):
    """Network for the Document Reader module of DrQA."""
    RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}

    def __init__(self,
                 opt,
                 padding_idx=0,
                 embedding=None,
                 normalize_emb=False,
                 embedding_order=True):
        super(RnnDocReader, self).__init__()
        # Store config
        self.opt = opt

        # Word embeddings
        if opt['pretrained_words']:
            assert embedding is not None
            self.embedding = nn.Embedding(embedding.size(0),
                                          embedding.size(1),
                                          padding_idx=padding_idx)
            if normalize_emb: normalize_emb_(embedding)
            self.embedding.weight.data = embedding

            if opt['fix_embeddings']:
                assert opt['tune_partial'] == 0
                for p in self.embedding.parameters():
                    p.requires_grad = False
            elif opt['tune_partial'] > 0:
                assert opt['tune_partial'] + 2 < embedding.size(0)
                fixed_embedding = embedding[opt['tune_partial'] + 2:]
                self.register_buffer('fixed_embedding', fixed_embedding)
                self.fixed_embedding = fixed_embedding
        else:  # random initialized
            self.embedding = nn.Embedding(opt['vocab_size'],
                                          opt['embedding_dim'],
                                          padding_idx=padding_idx)
        if opt['pos']:
            self.pos_embedding = nn.Embedding(opt['pos_size'], opt['pos_dim'])
            if normalize_emb: normalize_emb_(self.pos_embedding.weight.data)
        if opt['ner']:
            self.ner_embedding = nn.Embedding(opt['ner_size'], opt['ner_dim'])
            if normalize_emb: normalize_emb_(self.ner_embedding.weight.data)
        # Projection for attention weighted question
        if opt['use_qemb']:
            # 3 * embedding_dim: assumes the CoVe outputs (2 * embedding_dim)
            # are concatenated onto the word embeddings in forward()
            self.qemb_match = layers.SeqAttnMatch(3 * opt['embedding_dim'])
        if opt['use_cove']:
            self.cove_embedding = MTLSTM(n_vocab=embedding.size(0),
                                         vectors=embedding.clone())
            if not opt['fine_tune']:
                for p in self.cove_embedding.parameters():
                    p.requires_grad = False

        # Input size to RNN: word emb + question emb + manual features
        doc_input_size = opt['embedding_dim'] + opt['num_features']
        question_input_size = opt['embedding_dim']
        if opt['use_qemb']:
            # The attended question vectors match x2_emb's width: with CoVe
            # concatenated that is 3 * embedding_dim (assuming
            # layers.SeqAttnMatch returns a weighted sum of its second
            # argument, as in stock DrQA).
            doc_input_size += (3 * opt['embedding_dim'] if opt['use_cove']
                               else opt['embedding_dim'])
        if opt['pos']:
            doc_input_size += opt['pos_dim']
        if opt['ner']:
            doc_input_size += opt['ner_dim']
        if opt['use_cove']:
            # for Cove
            doc_input_size += 2 * opt['embedding_dim']
            question_input_size += 2 * opt['embedding_dim']

        print('doc_input_size:', doc_input_size)
        self.attention_rnns = custom.AttentionRNN(
            opt,
            doc_input_size=doc_input_size,
            question_input_size=question_input_size,
            ratio=opt['reduction_ratio'])

        # Output sizes of rnn encoders
        doc_hidden_size = (2 * opt['hidden_size']
                           + opt['hidden_size'] // opt['reduction_ratio'])
        question_hidden_size = (2 * opt['hidden_size']
                                + opt['hidden_size'] // opt['reduction_ratio'])

        # Question merging
        if opt['question_merge'] not in ['avg', 'self_attn']:
            raise NotImplementedError('question_merge = %s' %
                                      opt['question_merge'])
        if opt['question_merge'] == 'self_attn':
            self.self_attn = layers.LinearSeqAttn(question_hidden_size)

        # Bilinear attention for span start/end
        self.start_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
        )
        self.end_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
        )

    def forward(self, x1, x1_f, x1_pos, x1_ner, x1_mask, x2, x2_mask, x1_order,
                x2_order):
        """Inputs:
        x1 = document word indices             [batch * len_d]
        x1_f = document word features indices  [batch * len_d * nfeat]
        x1_pos = document POS tags             [batch * len_d]
        x1_ner = document entity tags          [batch * len_d]
        x1_mask = document padding mask        [batch * len_d]
        x2 = question word indices             [batch * len_q]
        x2_mask = question padding mask        [batch * len_q]
        """
        # Embed both document and question
        x1_emb = self.embedding(x1)
        if self.opt['use_cove']:
            x1_emb_cove = self.cove_embedding(
                x1,
                torch.LongTensor(x1.size(0)).fill_(x1.size(1)).cuda())
        #x1_emb_order = self.embedding_order(x1_order)

        x2_emb = self.embedding(x2)
        if self.opt['use_cove']:
            x2_emb_cove = self.cove_embedding(
                x2,
                torch.LongTensor(x2.size(0)).fill_(x2.size(1)).cuda())
        #x2_emb += self.embedding_order(x2_order)

        if self.opt['dropout_emb'] > 0:
            x1_emb = nn.functional.dropout(x1_emb,
                                           p=self.opt['dropout_emb'],
                                           training=self.training)
            x2_emb = nn.functional.dropout(x2_emb,
                                           p=self.opt['dropout_emb'],
                                           training=self.training)

        if self.opt['use_cove']:
            x2_emb = torch.cat([x2_emb, x2_emb_cove], dim=2)
            x1_emb = torch.cat([x1_emb, x1_emb_cove], dim=2)

        drnn_input_list = [x1_emb, x1_f]
        # Add attention-weighted question representation
        if self.opt['use_qemb']:
            x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
            drnn_input_list.append(x2_weighted_emb)
        if self.opt['pos']:
            x1_pos_emb = self.pos_embedding(x1_pos)
            if self.opt['dropout_emb'] > 0:
                x1_pos_emb = nn.functional.dropout(x1_pos_emb,
                                                   p=self.opt['dropout_emb'],
                                                   training=self.training)
            drnn_input_list.append(x1_pos_emb)
        if self.opt['ner']:
            x1_ner_emb = self.ner_embedding(x1_ner)
            if self.opt['dropout_emb'] > 0:
                x1_ner_emb = nn.functional.dropout(x1_ner_emb,
                                                   p=self.opt['dropout_emb'],
                                                   training=self.training)
            drnn_input_list.append(x1_ner_emb)
        drnn_input = torch.cat(drnn_input_list, 2)
        # Encode document with RNN
        doc_hiddens, question_hiddens = self.attention_rnns(
            drnn_input, x1_mask, x2_emb, x2_mask)
        if self.opt['question_merge'] == 'avg':
            q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask)
        elif self.opt['question_merge'] == 'self_attn':
            q_merge_weights = self.self_attn(question_hiddens, x2_mask)
        question_hidden = layers.weighted_avg(question_hiddens,
                                              q_merge_weights)

        start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
        end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
        return start_scores, end_scores
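
One operational detail worth noting for this variant: with tune_partial > 0 the optimizer still updates every row of self.embedding.weight, so the rows mirrored in the fixed_embedding buffer have to be copied back after each step. A minimal sketch, following the stock DrQA recipe (the function name reset_embeddings is an assumption):

def reset_embeddings(model, opt):
    # Restore the frozen embedding rows the optimizer may have modified;
    # rows [tune_partial + 2:] mirror the registered `fixed_embedding` buffer.
    if opt['tune_partial'] > 0:
        offset = opt['tune_partial'] + 2
        if offset < model.embedding.weight.size(0):
            model.embedding.weight.data[offset:] = model.fixed_embedding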
Example #3
import torch
import torch.nn as nn

import layers  # project-local module (path assumed)
from cove import MTLSTM  # CoVe encoder (McCann et al., 2017)


class FusionNetReader(nn.Module):
    """FusionNet-style reader (fully-aware multi-level fusion) built on DrQA layers."""
    RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}

    def __init__(self, args):
        super(FusionNetReader, self).__init__()
        # Store config
        self.args = args

        # Word embeddings (+1 for padding)
        self.embedding = nn.Embedding(args.vocab_size,
                                      args.embedding_dim,
                                      padding_idx=0)
        if args.use_cove and args.embedding_dim == 300:
            # CoVe's MT-LSTM expects 300-d GloVe inputs; instantiated without
            # its own embedding table, it consumes our embeddings directly.
            self.cove_encoder = MTLSTM()
            for p in self.cove_encoder.parameters():
                p.requires_grad = False

        if args.use_qemb:
            self.qemb_match = layers.SeqAttnMatch(args.embedding_dim)

        # Input size to RNN: word emb + cove emb + manual features + question emb
        doc_input_size = args.embedding_dim + args.num_features
        question_input_size = args.embedding_dim
        if args.use_cove:
            doc_input_size += 2 * args.cove_embedding_dim
            question_input_size += 2 * args.cove_embedding_dim
        if args.use_qemb:
            doc_input_size += args.embedding_dim

        # Reading component (low-level layer)
        self.reading_low_level_doc_rnn = layers.StackedBRNN(
            input_size=doc_input_size,
            hidden_size=args.hidden_size,
            num_layers=1,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            padding=args.rnn_padding
        )

        self.reading_low_level_question_rnn = layers.StackedBRNN(
            input_size=question_input_size,
            hidden_size=args.hidden_size,
            num_layers=1,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            padding=args.rnn_padding
        )

        # Reading component (high-level layer)
        self.reading_high_level_doc_rnn = layers.StackedBRNN(
            input_size=args.hidden_size * 2,
            hidden_size=args.hidden_size,
            num_layers=1,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            padding=args.rnn_padding
        )

        self.reading_high_level_question_rnn = layers.StackedBRNN(
            input_size=args.hidden_size * 2,
            hidden_size=args.hidden_size,
            num_layers=1,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            padding=args.rnn_padding
        )

        # Question understanding component
        # input: [low_level_question, high_level_question]
        self.understanding_question_rnn = layers.StackedBRNN(
            input_size=args.hidden_size * 4,
            hidden_size=args.hidden_size,
            num_layers=1,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            padding=args.rnn_padding
        )

        # [word_embedding, cove_embedding, low_level_doc_hidden, high_level_doc_hidden]
        history_of_word_size = args.embedding_dim + 2 * args.cove_embedding_dim + 4 * args.hidden_size

        # Fully-aware attention over the history of word. Earlier experiments
        # with MatrixAttention over (symmetric) bilinear similarities were
        # replaced by the symmetric-bilinear attention layers below.
        self.low_level_matrix_attention_layer = layers.SymBilinearAttnMatch(
            history_of_word_size, args.attention_size)
        self.high_level_matrix_attention_layer = layers.SymBilinearAttnMatch(
            history_of_word_size, args.attention_size)
        self.understanding_matrix_attention_layer = layers.SymBilinearAttnMatch(
            history_of_word_size, args.attention_size)

        # Multi-level rnn
        # input: [low_level_doc, high_level_doc, low_level_fusion_doc, high_level_fusion_doc,
        # understanding_level_question_fusion_doc]
        self.multi_level_rnn = layers.StackedBRNN(
            input_size=args.hidden_size * 2 * 5,
            hidden_size=args.hidden_size,
            num_layers=1,
            padding=args.rnn_padding
        )

        # [word_embedding, cove_embedding, low_level_doc_hidden, high_level_doc_hidden, low_level_doc_question_vector,
        # high_level_doc_question_vector, understanding_doc_question_vector, fa_multi_level_doc_hidden]
        history_of_doc_word_size = history_of_word_size + 4 * 2 * args.hidden_size

        # Symmetric-bilinear attention for self-boosted fusion (replacing the
        # earlier MatrixAttention experiments, as above)
        self.self_boosted_matrix_attention_layer = layers.SymBilinearAttnMatch(
            history_of_doc_word_size, args.attention_size)

        # Fully-Aware Self-Boosted fusion rnn
        # input: [fully_aware_encoded_doc(hidden state from last layer) ,self_boosted_fusion_doc]
        self.understanding_doc_rnn = layers.StackedBRNN(
            input_size=args.hidden_size * 2 * 2,
            hidden_size=args.hidden_size,
            num_layers=1,
            padding=args.rnn_padding
        )

        # Output sizes of rnn
        doc_hidden_size = 2 * args.hidden_size
        question_hidden_size = 2 * args.hidden_size
        if args.concat_rnn_layers:
            doc_hidden_size *= args.doc_layers
            question_hidden_size *= args.question_layers

        # Question merging
        self.question_self_attn = layers.LinearSeqAttn(question_hidden_size)

        self.start_attn = layers.BilinearSeqAttn(doc_hidden_size, question_hidden_size, log_normalize=False)

        self.start_gru = nn.GRU(doc_hidden_size, args.hidden_size * 2, batch_first=True)

        self.end_attn = layers.BilinearSeqAttn(doc_hidden_size, question_hidden_size, log_normalize=False)

    def forward(self, x1, x1_f, x1_mask, x2, x2_mask):
        """Inputs:
        x1 = document word indices             [batch * len_d]
        x1_mask = document padding mask        [batch * len_d]
        x1_f = document word features indices  [batch * len_d * nfeat]
        x2 = question word indices             [batch * len_q]
        x2_mask = question padding mask        [batch * len_q]
        """
        # Embed both document and question
        x1_word_emb = self.embedding(x1)  # [batch, len_d, embedding_dim]
        x2_word_emb = self.embedding(x2)  # [batch, len_q, embedding_dim]

        # Sequence lengths from the padding masks (0 marks a real token)
        x1_lengths = x1_mask.data.eq(0).long().sum(1)  # [batch]
        x2_lengths = x2_mask.data.eq(0).long().sum(1)  # [batch]

        x1_cove_emb = self.cove_encoder(x1_word_emb, x1_lengths)
        x2_cove_emb = self.cove_encoder(x2_word_emb, x2_lengths)

        x1_emb = torch.cat([x1_word_emb, x1_cove_emb], dim=-1)
        x2_emb = torch.cat([x2_word_emb, x2_cove_emb], dim=-1)

        # Dropout on embeddings
        if self.args.dropout_emb > 0:
            x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb,
                                           training=self.training)
            x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb,
                                           training=self.training)
        # Form document encoding inputs
        drnn_input = [x1_emb]

        # Add attention-weighted question representation
        if self.args.use_qemb:
            x2_weighted_emb = self.qemb_match(x1_word_emb, x2_word_emb,
                                              x2_mask)  # [batch, len_d, embedding_dim]
            drnn_input.append(x2_weighted_emb)

        # Add manual features
        if self.args.num_features > 0:
            drnn_input.append(x1_f)

        # Reading: low-level encoding, shape: [batch, len, 2*hidden_size]
        low_level_doc_hiddens = self.reading_low_level_doc_rnn(
            torch.cat(drnn_input, 2), x1_mask)
        low_level_question_hiddens = self.reading_low_level_question_rnn(
            x2_emb, x2_mask)

        # Reading: high-level encoding, shape: [batch, len, 2*hidden_size]
        high_level_doc_hiddens = self.reading_high_level_doc_rnn(
            low_level_doc_hiddens, x1_mask)
        high_level_question_hiddens = self.reading_high_level_question_rnn(
            low_level_question_hiddens, x2_mask)

        # Question understanding, shape: [batch, len_q, 2*hidden_size]
        understanding_question_hiddens = self.understanding_question_rnn(
            torch.cat([low_level_question_hiddens,
                       high_level_question_hiddens], 2), x2_mask)

        # History of word, shape: [batch, len_d, history_of_word_size]
        history_of_doc_word = torch.cat([x1_word_emb, x1_cove_emb,
                                         low_level_doc_hiddens,
                                         high_level_doc_hiddens], dim=2)
        # History of word, shape: [batch, len_q, history_of_word_size]
        history_of_question_word = torch.cat([x2_word_emb, x2_cove_emb,
                                              low_level_question_hiddens,
                                              high_level_question_hiddens], dim=2)

        # Fully-aware multi-level attention, shape: [batch, len_d, 2*hidden_size]
        low_level_doc_question_vectors = self.low_level_matrix_attention_layer(
            history_of_doc_word, history_of_question_word, x2_mask,
            low_level_question_hiddens)
        high_level_doc_question_vectors = self.high_level_matrix_attention_layer(
            history_of_doc_word, history_of_question_word, x2_mask,
            high_level_question_hiddens)
        understanding_doc_question_vectors = self.understanding_matrix_attention_layer(
            history_of_doc_word, history_of_question_word, x2_mask,
            understanding_question_hiddens)


        # Encode multi-level hiddens and vectors
        fa_multi_level_doc_hiddens = self.multi_level_rnn(
            torch.cat([low_level_doc_hiddens, high_level_doc_hiddens,
                       low_level_doc_question_vectors,
                       high_level_doc_question_vectors,
                       understanding_doc_question_vectors], dim=2),
            x1_mask)

        # Extended history of word for self-boosted fusion
        history_of_doc_word2 = torch.cat([x1_word_emb, x1_cove_emb,
                                          low_level_doc_hiddens,
                                          high_level_doc_hiddens,
                                          low_level_doc_question_vectors,
                                          high_level_doc_question_vectors,
                                          understanding_doc_question_vectors,
                                          fa_multi_level_doc_hiddens], dim=2)

        # Self-boosted fused vectors, shape: [batch, len_d, 2*hidden_size]
        self_boosted_vectors = self.self_boosted_matrix_attention_layer(
            history_of_doc_word2, history_of_doc_word2, x1_mask,
            fa_multi_level_doc_hiddens)

        # Encode vectors and hiddens
        # shape: [batch, len_d, 2*hidden_size]
        understanding_doc_hiddens = self.understanding_doc_rnn(
            torch.cat([fa_multi_level_doc_hiddens, self_boosted_vectors],
                      dim=2), x1_mask)

        # shape: [batch, len_q]
        q_merge_weights = self.question_self_attn(understanding_question_hiddens, x2_mask)
        # shape: [batch, 2*hidden_size]
        question_hidden = layers.weighted_avg(understanding_question_hiddens, q_merge_weights)

        # Predict start and end positions
        # shape: [batch, len_d]; plain softmax (not log-softmax) since the
        # attention layers were built with log_normalize=False
        start_scores = self.start_attn(understanding_doc_hiddens, question_hidden, x1_mask)
        # shape: [batch, 2*hidden_size]
        gru_input = layers.weighted_avg(understanding_doc_hiddens, start_scores)
        # shape: [batch, 1, 2*hidden_size]
        gru_input = gru_input.unsqueeze(1)
        # shape: [1, batch, 2*hidden_size]
        question_hidden = question_hidden.unsqueeze(0)
        _, memory_hidden = self.start_gru(gru_input, question_hidden)
        # shape: [batch, 2*hidden_size]
        memory_hidden = memory_hidden.squeeze(0)
        # shape: [batch, len_d]
        end_scores = self.end_attn(understanding_doc_hiddens, memory_hidden, x1_mask)
        # Training uses an NLL criterion, so move the probabilities to log
        # space (epsilon added for numerical stability)
        if self.training:
            start_scores = torch.log(start_scores.add(1e-8))
            end_scores = torch.log(end_scores.add(1e-8))
        return start_scores, end_scores
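
Since this model returns probabilities (not log-probabilities) at inference time, a best answer span can be decoded by maximizing the product of start and end probabilities over valid spans. A minimal sketch of that decode step, mirroring the standard DrQA procedure (the max_len cap and the function name are assumptions):

def decode_span(start_scores, end_scores, max_len=15):
    """Return the best (start, end) per example with 0 <= end - start < max_len."""
    # Outer product of probabilities: scores[b, i, j] = P(start=i) * P(end=j)
    scores = start_scores.unsqueeze(2) * end_scores.unsqueeze(1)
    # Zero out spans with end < start or length >= max_len
    scores = scores.triu().tril(max_len - 1)
    batch, len_d, _ = scores.size()
    idx = scores.view(batch, -1).argmax(dim=1)
    return idx // len_d, idx % len_d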