Code example #1
File: net.py Project: keep-smile-001/opentqa
    def forward(self, que_ix, opt_ix, dia, dia_node_ix, ins_dia, cp_ix):
        """
        :param que_ix: the token indices of the questions
        :param opt_ix: the token indices of the options
        :param dia: the diagram corresponding to the above question
        :param dia_node_ix: the indices of the diagram nodes
        :param ins_dia: the instructional diagram corresponding to the lesson that contains the above question
        :param cp_ix: the closest paragraph, extracted by the TF-IDF method
        """
        batch_size = que_ix.shape[0]
        que_mask = make_mask(que_ix.unsqueeze(2))
        que_feat = self.embedding(que_ix)
        que_feat, _ = self.encode_lang(que_feat)
        que_feat = self.flat(que_feat, que_mask)

        opt_ix = opt_ix.reshape(-1, self.cfgs.max_opt_token)
        opt_mask = make_mask(opt_ix.unsqueeze(2))

        opt_feat = self.embedding(opt_ix)
        opt_feat, _ = self.encode_lang(opt_feat)
        opt_feat = self.flat(opt_feat, opt_mask)
        opt_feat = opt_feat.reshape(batch_size, self.cfgs.max_dq_ans, -1)

        dia_feat = self.simclr(dia)
        dia_feat = dia_feat.reshape(batch_size, -1)

        dia_feat_dim = dia_feat.shape[-1]
        dia_feat = dia_feat.reshape(batch_size, 1, dia_feat_dim)
        fusion_feat = self.backbone(dia_feat, que_feat)  # [8, 4096]; the BAN backbone yields [b, 15, 1024]

        # fusion_feat = self.flatten(fusion_feat.sum(1))
        fusion_feat = fusion_feat.repeat(1, self.cfgs.max_dq_ans).reshape(batch_size, self.cfgs.max_dq_ans, -1)
        fuse_opt_feat = torch.cat((fusion_feat, opt_feat), dim=-1)  # (8, 4, 4096) and (8, 4, 1024)
        proj_feat = self.classifer(fuse_opt_feat).squeeze(-1)

        return proj_feat
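
All of these forward passes rely on a make_mask helper that is not shown on this page. A minimal sketch of the padding-mask behaviour the calls above appear to assume (an illustration, not the opentqa implementation):

import torch

def make_mask(feature):
    # Hypothetical sketch: mark positions whose feature vectors are all zeros,
    # i.e. padding tokens (que_ix.unsqueeze(2)) or entirely padded options (opt_ix).
    return torch.sum(torch.abs(feature), dim=-1) == 0

Under this assumption, make_mask(que_ix.unsqueeze(2)) gives a (batch, seq_len) mask of padded question tokens, while make_mask(opt_ix) gives a (batch, max_ans) mask of absent options, which is consistent with the masked_fill calls in the later examples.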
Code example #2
    def forward(self, que_ix, opt_ix, cp_ix):
        """
        :param que_ix: the token indices of the questions
        :param opt_ix: the token indices of the options
        :param cp_ix: the closest paragraph, extracted by the TF-IDF method
        """
        batch_size = que_ix.shape[0]
        que_feat = self.embedding(que_ix)
        que_feat, _ = self.lang(que_feat)

        opt_num = make_mask(opt_ix)  # mask of padded options, used later to recover the actual number of options
        opt_feat = self.embedding(opt_ix)
        opt_feat = opt_feat.reshape(
            (-1, self.cfgs.max_opt_token, self.cfgs.word_emb_size))
        _, opt_feat = self.lang(opt_feat)
        opt_feat = opt_feat.transpose(1, 0).reshape(
            self.cfgs.max_ndq_ans * batch_size,
            -1).reshape(batch_size, self.cfgs.max_ndq_ans, -1)
        cp_ix = cp_ix.reshape(-1, self.cfgs.max_sent_token)
        cp_feat = self.embedding(cp_ix)
        cp_feat, _ = self.lang(cp_feat)
        cp_feat = cp_feat.reshape(
            batch_size, self.cfgs.max_sent * self.cfgs.max_sent_token, -1)

        fusion_feat = self.backbone(que_feat, cp_feat)

        fusion_feat = self.flatten(fusion_feat.sum(1))
        fusion_feat = fusion_feat.repeat(1, self.cfgs.max_ndq_ans).reshape(
            batch_size, self.cfgs.max_ndq_ans, -1)

        proj_feat = self.classifer(fusion_feat, opt_feat)
        proj_feat = proj_feat.masked_fill(opt_num, -1e9)
        return proj_feat, torch.sum((opt_num == 0), dim=-1)
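
The last two lines show the option-masking pattern shared by the NDQ models on this page: logits of padded options are pushed to -1e9 and the number of real options is returned alongside the scores. A small self-contained illustration (the shapes are assumptions):

import torch

# opt_num is True for padded (absent) options.
logits = torch.randn(2, 4)                        # (batch, max_ndq_ans)
opt_num = torch.tensor([[False, False, True, True],
                        [False, False, False, True]])
logits = logits.masked_fill(opt_num, -1e9)        # padded options can no longer win the argmax
real_opt_count = torch.sum(opt_num == 0, dim=-1)  # tensor([2, 3])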
Code example #3
File: net.py Project: keep-smile-001/opentqa
    def forward(self, que_ix, opt_ix, cp_ix):
        """
        :param que_ix: the token indices of the questions
        :param opt_ix: the token indices of the options
        :param cp_ix: the closest paragraph, extracted by the TF-IDF method
        """
        batch_size = que_ix.shape[0]
        lang_feat_mask = self.make_mask(que_ix.unsqueeze(2))
        que_feat = self.embedding(que_ix)
        que_feat, _ = self.lstm(que_feat)

        opt_num = make_mask(opt_ix)  # mask of padded options, used later to recover the actual number of options
        opt_ix = opt_ix.reshape(-1, self.cfgs.max_opt_token)
        opt_feat_mask = self.make_mask(opt_ix.unsqueeze(2))
        opt_feat = self.embedding(opt_ix)
        opt_feat, _ = self.lstm(opt_feat)

        opt_feat = self.attflat_lang(opt_feat, opt_feat_mask)
        opt_feat = opt_feat.reshape(batch_size, self.cfgs.max_ndq_ans,
                                    -1)  # opt_feat: [batch*4, 1024] -> [batch, max_ndq_ans, 1024]

        cp_ix = cp_ix.reshape(batch_size, -1)
        cp_mask = make_mask(cp_ix.unsqueeze(2))
        cp_feat = self.embedding(cp_ix)
        cp_feat, _ = self.lstm(cp_feat)

        # Backbone Framework
        lang_feat, cp_feat = self.backbone(que_feat, cp_feat, lang_feat_mask,
                                           cp_mask)
        lang_feat = self.attflat_lang(lang_feat, lang_feat_mask)

        cp_feat = self.attflat_img(cp_feat, cp_mask)

        proj_feat = lang_feat + cp_feat
        proj_feat = self.proj_norm(proj_feat)

        proj_feat = proj_feat.repeat(1, self.cfgs.max_ndq_ans).reshape(
            batch_size, self.cfgs.max_ndq_ans, self.cfgs.flat_out_size)
        fuse_feat = torch.cat((proj_feat, opt_feat), dim=-1)
        proj_feat = self.proj(fuse_feat).squeeze(-1)

        proj_feat = proj_feat.masked_fill(opt_num, -1e9)
        return proj_feat, torch.sum((opt_num == 0), dim=-1)
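
The repeat/reshape step above expands one fused question-paragraph vector into one copy per answer option before concatenating it with the option features. A minimal shape walk-through (the sizes 8, 4 and 1024 are assumptions):

import torch

batch_size, max_ndq_ans, flat_out_size = 8, 4, 1024
proj_feat = torch.randn(batch_size, flat_out_size)             # one fused vector per sample
proj_feat = proj_feat.repeat(1, max_ndq_ans).reshape(
    batch_size, max_ndq_ans, flat_out_size)                    # (8, 4, 1024): one copy per option
opt_feat = torch.randn(batch_size, max_ndq_ans, flat_out_size)
fuse_feat = torch.cat((proj_feat, opt_feat), dim=-1)           # (8, 4, 2048)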
Code example #4
    def forward(self, que_ix, opt_ix, dia, ins_dia, cp_ix):
        """
        :param que_ix: the token indices of the questions
        :param opt_ix: the token indices of the options
        :param dia: the diagram corresponding to the above question
        :param ins_dia: the instructional diagram corresponding to the lesson that contains the above question
        :param cp_ix: the closest paragraph, extracted by the TF-IDF method
        """
        batch_size = que_ix.shape[0]
        que_feat = self.embedding(que_ix)
        que_feat, que_hidden = self.rnn_qo(que_feat)

        # opt_feat = self.embedding(opt_ix)
        # opt_feat = opt_feat.reshape((-1, self.cfgs.max_opt_token, self.cfgs.word_emb_size))
        # _, opt_feat = self.rnn_qo(opt_feat)
        # opt_feat = opt_feat.transpose(1, 0).reshape(self.cfgs.max_dq_ans * batch_size, -1).reshape(batch_size,
        #                                                                                         self.cfgs.max_dq_ans,
        #                                                                                         -1)
        opt_ix = opt_ix.reshape(-1, self.cfgs.max_opt_token)
        opt_mask = make_mask(opt_ix.unsqueeze(2))
        opt_feat = self.embedding(opt_ix)
        opt_feat, _ = self.rnn_qo(opt_feat)
        opt_feat = self.flatten_opt(opt_feat, opt_mask)
        opt_feat = opt_feat.reshape(batch_size, self.cfgs.max_dq_ans, -1)

        cp_feat = self.embedding(cp_ix)
        cp_feat, _ = self.rnn_cp(
            cp_feat.reshape(batch_size, -1, self.cfgs.word_emb_size))

        dia_feat = self.simclr(dia)
        dia_feat = dia_feat.reshape(batch_size, -1)

        # if self.cfgs.use_ins_dia in 'True':
        #     ins_dia_feat = ins_dia.reshape(-1, 3, self.cfgs.input_size, self.cfgs.input_size)
        #     ins_dia_feat = self.simclr(ins_dia_feat)
        #     ins_dia_feat = ins_dia_feat.reshape(batch_size, -1, self.cfgs.dia_feat_size)
        #     related_ins_dia_feat = get_related_diagram(dia_feat, ins_dia_feat)
        #     dia_feat = (dia_feat + related_ins_dia_feat) / 2.0

        # fusion_feat = self.backbone(que_opt_feat, dia_feat, cp_feat)
        fusion_feat = self.forward_with_divide_and_rule(
            que_feat, dia_feat, ins_dia, cp_feat, que_hidden)
        fusion_feat = self.flatten(fusion_feat.sum(1))
        fusion_feat = fusion_feat.repeat(1, self.cfgs.max_dq_ans).reshape(
            batch_size, self.cfgs.max_dq_ans, -1)
        proj_feat = self.classifer1(fusion_feat, opt_feat)

        return proj_feat
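
For reference, a hypothetical smoke test for a DQ model with this signature; all shapes and vocabulary sizes below are assumptions, not the opentqa defaults:

import torch

que_ix = torch.randint(1, 1000, (2, 20))        # 2 questions of up to 20 tokens
opt_ix = torch.randint(0, 1000, (2, 4, 8))      # 4 candidate options of up to 8 tokens each
dia = torch.randn(2, 3, 224, 224)               # one diagram per question
ins_dia = torch.randn(2, 3, 3, 224, 224)        # instructional diagrams of the lesson
cp_ix = torch.randint(0, 1000, (2, 5, 10))      # closest paragraph: 5 sentences x 10 tokens
# proj_feat = model(que_ix, opt_ix, dia, ins_dia, cp_ix)   # expected shape: (2, 4) option scores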
Code example #5
    def forward(self, que_ix, opt_ix, cp_ix):
        """
        :param que_ix: the token indices of the questions
        :param opt_ix: the token indices of the options
        :param cp_ix: the closest paragraph, extracted by the TF-IDF method
        """

        batch_size = que_ix.shape[0]
        que_mask = make_mask(que_ix.unsqueeze(2))
        que_feat = self.embedding(que_ix)
        que_feat, _ = self.encode_lang(que_feat)
        que_feat = que_feat.reshape(batch_size, -1, self.hid_dim)
        ##print(que_feat.size())
        opt_num = make_mask(opt_ix)
        opt_ix = opt_ix.reshape(-1, self.cfgs.max_opt_token)

        opt_mask = make_mask(opt_ix.unsqueeze(2))

        opt_feat = self.embedding(opt_ix)
        _, opt_feat = self.encode_lang(opt_feat)
        #opt_feat = self.flat(opt_feat, opt_mask)
        #print(_.size(),opt_feat.size())
        opt_feat = opt_feat.reshape(batch_size, self.cfgs.max_ndq_ans, -1)

        cp_ix = cp_ix.reshape(-1, self.cfgs.max_sent_token)
        cp_feat = self.embedding(cp_ix)
        cp_feat, _ = self.encode_lang2(cp_feat)
        cp_feat = cp_feat.reshape(
            batch_size, self.cfgs.max_sent * self.cfgs.max_sent_token, -1)

        batch_size, img_num, obj_num, feat_size = cp_feat.size()[0], 1, cp_feat.size()[1], cp_feat.size()[2]
        #print("-----------------------------\n",cp_feat.size())
        assert img_num == 1 and feat_size == 2048
        cp_feat = cp_feat.reshape(batch_size, -1, feat_size)

        output_lang, output_img, output_cross = self.bert_encoder(
            sents=que_feat, feats=cp_feat)
        # output_cross = output_cross.view(-1, self.hid_dim*2) ## original
        output_cross = output_cross.view(-1, self.hid_dim)
        ##print(output_lang.size(), output_img.size(), output_cross.size())

        #### new experiment for relationship
        relate_lang_stack_1 = output_lang.view(output_lang.size()[0], 1,
                                               output_lang.size()[1],
                                               output_lang.size()[2])
        relate_lang_stack_2 = output_lang.view(output_lang.size()[0],
                                               output_lang.size()[1], 1,
                                               output_lang.size()[2])
        # relate_lang_stack = relate_lang_stack_1 + relate_lang_stack_2 ## [64, 20, 20, 768]
        relate_lang_stack_1 = relate_lang_stack_1.repeat(
            1,
            output_lang.size()[1], 1, 1
        )  ## [64, 20, 20, 768] second dim repeat 10 times, others not change
        relate_lang_stack_2 = relate_lang_stack_2.repeat(
            1, 1,
            output_lang.size()[1], 1
        )  ## [64, 20, 20, 768] third dim repeat 10 times, others not change
        relate_lang_stack = torch.cat(
            (relate_lang_stack_1, relate_lang_stack_2),
            3)  ## [64, 20, 20, 768*2]

        relate_lang_stack = relate_lang_stack.view(-1,
                                                   output_lang.size()[2] * 2)
        relate_lang_stack = self.lang_2_to_1(relate_lang_stack)
        relate_lang_stack = relate_lang_stack.view(output_lang.size()[0],
                                                   output_lang.size()[1],
                                                   output_lang.size()[1],
                                                   output_lang.size()[2])

        relate_img_stack_1 = output_img.view(
            output_img.size()[0], 1,
            output_img.size()[1],
            output_img.size()[2])  #[16,1,1,768]
        relate_img_stack_2 = output_img.view(
            output_img.size()[0],
            output_img.size()[1], 1,
            output_img.size()[2])  #[16,1,1,768]
        ##print("r_i_s",relate_img_stack_1.size(),relate_img_stack_2.size())
        relate_img_stack = relate_img_stack_1 + relate_img_stack_2  ## [16, 1, 1, 768]
        # relate_img_stack_1 = relate_img_stack_1.repeat(1,output_img.size()[1],1,1)  ## [64, 20, 20, 768] second dim repeat 10 times, others not change
        # relate_img_stack_2 = relate_img_stack_2.repeat(1,1,output_img.size()[1],1)  ## [64, 20, 20, 768] third dim repeat 10 times, others not change
        # relate_img_stack = torch.cat((relate_img_stack_1, relate_img_stack_2), 3)

        # relate_img_stack = relate_img_stack.view(-1, output_lang.size()[2]*2)
        # relate_img_stack = self.lang_2_to_1(relate_img_stack)
        # relate_img_stack = relate_img_stack.view(output_img.size()[0], output_img.size()[1], output_img.size()[1], output_img.size()[2])

        relate_lang_stack = relate_lang_stack.view(
            relate_lang_stack.size()[0],
            relate_lang_stack.size()[1] * relate_lang_stack.size()[2],
            relate_lang_stack.size()[3])  ## [64, 400, 768] or 768*2
        relate_img_stack = relate_img_stack.view(
            relate_img_stack.size()[0],
            relate_img_stack.size()[1] * relate_img_stack.size()[2],
            relate_img_stack.size()[3])  ## [64, 1296, 768] or 768*2
        ##print("relate_i_s",relate_img_stack.size())
        ### a beautiful way
        relate_lang_ind = torch.tril_indices(output_lang.size()[1],
                                             output_lang.size()[1], -1).cuda(0)
        relate_lang_ind[1] = relate_lang_ind[1] * output_lang.size()[1]
        relate_lang_ind = relate_lang_ind.sum(0)
        relate_lang_stack = relate_lang_stack.index_select(
            1, relate_lang_ind)  ## [64, 190, 768] or 768*2

        relate_img_ind = torch.tril_indices(output_img.size()[1],
                                            output_img.size()[1], -1).cuda(0)
        #print("rii", relate_img_ind)
        relate_img_ind[1] = relate_img_ind[1] * output_img.size()[1]
        relate_img_ind = relate_img_ind.sum(0)
        #print("rii",relate_img_ind.size())
        relate_img_stack = relate_img_stack.index_select(
            1, relate_img_ind)  ## [64, 630, 768] or 768*2
        #print("ris", relate_img_stack.size())
        ## reshape the relate_lang_stack and relate_img_stack
        tmp_lang_stack = relate_lang_stack.view(-1, self.hid_dim)  # sum
        tmp_img_stack = relate_img_stack.view(-1, self.hid_dim)  # sum
        # tmp_lang_stack = relate_lang_stack.view(-1, self.hid_dim*2)   # cat
        # tmp_img_stack = relate_img_stack.view(-1, self.hid_dim*2)     # cat

        lang_candidate_relat_score = self.lang_relation(tmp_lang_stack)
        img_candidate_relat_score = self.img_relation(tmp_img_stack)

        lang_candidate_relat_score = lang_candidate_relat_score.view(
            output_lang.size()[0],
            relate_lang_stack.size()[1])  ##(64, 190)
        img_candidate_relat_score = img_candidate_relat_score.view(
            output_img.size()[0],
            relate_img_stack.size()[1])  ## (64,630)
        ##print("icrs",img_candidate_relat_score)
        ##print("-1")
        #time.sleep(10)
        _, topk_lang_index = torch.topk(lang_candidate_relat_score,
                                        self.top_k_value,
                                        sorted=False)  ##(64, 10)
        _, topk_img_index = torch.topk(img_candidate_relat_score,
                                       self.top_k_value,
                                       sorted=False)  ##(64, 10)
        ##print("0")
        #time.sleep(10)
        list_lang_relat = []
        list_img_relat = []
        for i in range(0, output_lang.size()[0]):
            tmp = torch.index_select(relate_lang_stack[i], 0,
                                     topk_lang_index[i])  ## [10, 768] or 768*2
            list_lang_relat.append(tmp)
        #print("1")
        #time.sleep(10)
        for i in range(0, output_img.size()[0]):
            tmp = torch.index_select(relate_img_stack[i], 0, topk_img_index[i])
            #tmp = relate_img_stack[i]  ## [10, 768] or 768*2
            list_img_relat.append(tmp)
            #print("tmp",tmp.size())
        #print("2")
        #time.sleep(10)
        #raise AssertionError
        #print("l-i-r", list_img_relat)
        lang_relat = torch.cat(list_lang_relat, 0)  ## [640, 768] or 768*2
        img_relat = torch.cat(list_img_relat, 0)  ## [640, 768] or 768*2
        #print("l-i-r", lang_relat.size(), img_relat.size())
        lang_relat = lang_relat.view(output_lang.size()[0], -1,
                                     self.hid_dim)  ## [64, 10, 768] or 768*2
        img_relat = img_relat.view(output_img.size()[0], -1,
                                   self.hid_dim)  ## [64, 10, 768] or 768*2
        # lang_relat = lang_relat.view(output_lang.size()[0], -1, self.hid_dim*2) ## [64, 10, 768] or 768*2
        # img_relat = img_relat.view(output_img.size()[0], -1, self.hid_dim*2) ## [64, 10, 768] or 768*2
        #print("l-i-r",lang_relat.size(),img_relat.size())

        relate_cross = torch.einsum('bld,brd->blr',
                                    F.normalize(lang_relat, p=2, dim=-1),
                                    F.normalize(img_relat, p=2, dim=-1))
        #print("rc", relate_cross.size())
        relate_cross = relate_cross.view(-1, 1,
                                         relate_cross.size()[1],
                                         relate_cross.size()[2])
        relate_conv_1 = self.rel_pool1(F.relu(self.rel_conv1(relate_cross)))
        # relate_conv_2 = self.rel_pool2(F.relu(self.rel_conv2(relate_conv_1)))
        #print("rcsss", relate_conv_1.size())
        relate_fc1 = F.relu(self.rel_fc1(relate_conv_1.view(-1, 32 * 4 * 4)))
        ##print(relate_fc1.size())
        relate_fc1 = relate_fc1.view(-1, self.hid_dim)
        logit4 = self.logit_fc4(relate_fc1)

        #### new experiment for cross modality
        output_cross = output_cross.view(-1, output_cross.size()[1])
        # #print("oc:",output_cross.size())
        cross_tuple = torch.split(output_cross,
                                  output_cross.size()[1] // 2,
                                  dim=1)
        ##print("ct:",cross_tuple[0].size(),cross_tuple[1].size())
        cross1 = cross_tuple[0].view(output_cross.size()[0], -1, 64)
        cross2 = cross_tuple[1].view(output_cross.size()[0], -1, 64)
        ##print("cross1",cross1.size(),cross2.size())
        ##print("c",output_cross.size())

        # #print(F.normalize(cross1,p=2,dim=-1).size())
        cross_1_2 = torch.einsum('bld,brd->blr',
                                 F.normalize(cross1, p=2, dim=-1),
                                 F.normalize(cross2, p=2, dim=-1))
        ##print("oc:", cross_1_2.size())
        cross_1_2 = cross_1_2.view(-1, 1,
                                   cross_1_2.size()[1],
                                   cross_1_2.size()[2])
        cross_conv_1 = self.cross_pool1(F.relu(self.cross_conv1(cross_1_2)))
        # cross_conv_2 = self.cross_pool2(F.relu(self.cross_conv2(cross_conv_1)))
        '''
        #### new experiment for lang and two images
        cross_img_sen = torch.einsum(
            'bld,brd->blr',
            F.normalize(output_lang, p=2, dim=-1),
            F.normalize(output_img, p=2, dim=-1)
        )

        cross_img_sen = cross_img_sen.view(-1, 1, cross_img_sen.size()[1], cross_img_sen.size()[2])
        entity_conv_1 = self.lang_pool1(F.relu(self.lang_conv1(cross_img_sen)))
        entity_conv_2 = self.lang_pool2(F.relu(self.lang_conv2(entity_conv_1)))

        ### new experiment for two images
        image_2_together = output_img.view(-1, output_img.size()[1], self.hid_dim * 2)
        # #print(image_2_together.size())
        images = torch.split(image_2_together, self.hid_dim // 2, dim=2)
        # #print(images[0].size(), images[1].size())
        
        image1 = images[0]
        image2 = images[1]
        '''
        '''
        cross_img_img = torch.einsum(
            'bld,brd->blr',
            F.normalize(image1, p=2, dim=-1),
            F.normalize(image2, p=2, dim=-1)
        )
        cross_img_img = cross_img_img.view(-1, 1, cross_img_img.size()[1], cross_img_img.size()[2])
        cross_img_conv_1 = self.img_pool1(F.relu(self.img_conv1(cross_img_img)))
        cross_img_conv_2 = self.img_pool2(F.relu(self.img_conv2(cross_img_conv_1)))
        # #print(cross_img_conv_2.size())

        img_fc1 = F.relu(self.img_fc1(cross_img_conv_2.view(-1, 32*7*7)))
        img_fc1 = img_fc1.view(-1, self.hid_dim)
        logit3 = self.logit_fc3(img_fc1)
        '''
        '''
        entity_fc1 = F.relu(self.lang_fc1(entity_conv_2.view(-1, 32 * 3 * 7)))
        entity_fc1 = entity_fc1.view(-1, self.hid_dim * 2)
        logit2 = self.logit_fc2(entity_fc1)
        '''

        ##print('cc1',cross_conv_1.size())
        cross_fc1 = F.relu(self.cross_fc1(cross_conv_1.view(-1, 32 * 2 * 2)))
        cross_fc1 = cross_fc1.view(-1, self.hid_dim)
        logit1 = self.logit_fc1(cross_fc1)
        # #print(logit1.size(),logit2.size(),logit3.size(),logit4.size())

        #cross_logit = torch.cat((logit1, logit2, logit3, logit4), 1)
        cross_logit = torch.cat((logit1, logit4), 1)
        #print("final:",logit1.size(),logit4.size())
        ##print("all:",cross_logit.size())
        cross_logit = cross_logit.repeat(1, self.cfgs.max_ndq_ans).reshape(
            batch_size, self.cfgs.max_ndq_ans, -1)
        #print("after all:", cross_logit.size(),opt_feat.size())
        cross_logit = torch.cat((cross_logit, opt_feat), dim=-1)
        #print("?",cross_logit.size())
        logit = self.final_classifier(cross_logit).squeeze(-1)

        #print(opt_num.size())
        logit = logit.masked_fill(opt_num, -1e9)
        return logit, torch.sum((opt_num == 0), dim=-1)
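
The tril_indices arithmetic above selects each unordered token (or region) pair exactly once out of the flattened N x N grid of pair features before the top-k relation scoring. A standalone illustration of the same index trick:

import torch

N, D = 5, 3
pair_grid = torch.randn(1, N * N, D)                # all N*N ordered pairs of one sample
idx = torch.tril_indices(N, N, -1)                  # strictly lower-triangular (i, j) pairs
flat_idx = idx[0] + idx[1] * N                      # same flattening arithmetic as in the code above
unique_pairs = pair_grid.index_select(1, flat_idx)  # (1, N*(N-1)//2, D) = (1, 10, 3)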
Code example #6
File: net.py Project: keep-smile-001/opentqa
    def forward(self, que_ix, opt_ix, dia, ins_dia, cp_ix):
        """
        :param que_ix: the token indices of the questions
        :param opt_ix: the token indices of the options
        :param dia: the diagram corresponding to the above question
        :param ins_dia: the instructional diagram corresponding to the lesson that contains the above question
        :param cp_ix: the closest paragraph, extracted by the TF-IDF method
        """

        batch_size = que_ix.shape[0]
        que_mask = make_mask(que_ix.unsqueeze(2))
        que_feat = self.embedding(que_ix)
        que_feat, _ = self.encode_lang(que_feat)

        opt_ix = opt_ix.reshape(-1, self.cfgs.max_opt_token)
        opt_mask = make_mask(opt_ix.unsqueeze(2))

        opt_feat = self.embedding(opt_ix)
        opt_feat, _ = self.encode_lang(opt_feat)
        opt_feat = self.flat(opt_feat, opt_mask)
        opt_feat = opt_feat.reshape(batch_size, self.cfgs.max_dq_ans, -1)

        cp_ix = cp_ix.reshape(-1, self.cfgs.max_sent_token)
        cp_mask = make_mask(cp_ix.unsqueeze(2))
        cp_feat = self.embedding(cp_ix)
        cp_feat, _ = self.encode_lang(cp_feat)
        cp_feat = self.flat(cp_feat, cp_mask)
        cp_feat = cp_feat.reshape(batch_size, self.cfgs.max_sent, -1)
        span_feat = get_span_feat(cp_feat, self.cfgs.span_width)
        cp_sent_mask = make_mask(cp_mask == False).reshape(batch_size, -1)
        cp_sent_mask = cp_sent_mask.unsqueeze(-1).repeat(
            1, 1, self.cfgs.span_width).reshape(batch_size, -1)

        dia_feat = self.simclr(dia)

        # compute the entropy of questions
        flattened_que_feat = self.flat(que_feat, que_mask)
        que_score = self.sigmoid(self.score_lang(flattened_que_feat))
        que_entropy = -que_score * torch.log2(que_score) - (1 - que_score) * torch.log2(1 - que_score)

        # compute the conditional entropy of questions given candidate spans
        span_num = span_feat.shape[1]
        flattened_que_feat_expand = flattened_que_feat.repeat(
            1, span_num).reshape(batch_size, span_num, -1)
        que_with_evi_feat = torch.cat((flattened_que_feat_expand, span_feat),
                                      dim=-1)
        que_with_evi_feat = self.pool3(que_with_evi_feat)
        span_feat = self.pool2(span_feat)

        evi_logit = self.sigmoid(self.score_lang(span_feat))
        que_evi_score = self.sigmoid(self.score_lang(que_with_evi_feat))
        que_cond_ent = evi_logit * (-que_evi_score * torch.log2(que_evi_score) -
                                    (1 - que_evi_score) * torch.log2(1 - que_evi_score))

        # compute the information gains of questions given evidence
        inf_gain = que_entropy - que_cond_ent.squeeze(-1)
        evi_ix = get_ix(inf_gain, cp_sent_mask)
        evi_feat = torch.cat(
            [span_feat[i][int(ix)] for i, ix in enumerate(evi_ix)],
            dim=0).reshape(batch_size, -1)
        evi_feat = evi_feat.repeat(1, self.cfgs.max_dq_ans).reshape(
            batch_size, self.cfgs.max_dq_ans, -1)

        # repeat flattened question features
        flattened_que_feat = flattened_que_feat.repeat(
            1, self.cfgs.max_dq_ans).reshape(batch_size, self.cfgs.max_dq_ans,
                                             -1)

        # use ban to fuse features of questions and diagrams
        fuse_que_dia = self.backbone(que_feat, dia_feat)
        fuse_que_dia = fuse_que_dia.mean(1).repeat(
            1, self.cfgs.max_dq_ans).reshape(batch_size, self.cfgs.max_dq_ans,
                                             -1)

        # fuse features of evidence and options
        fuse_evi_opt = evi_feat * opt_feat

        # fuse features of questions and options
        fuse_que_opt = flattened_que_feat * opt_feat

        # fuse features of questions and evidence
        fuse_que_evi = flattened_que_feat * evi_feat

        # fuse features of questions, options, diagrams and evidence
        fuse_all = flattened_que_feat * opt_feat * fuse_que_dia

        dia_feat = dia_feat.repeat(1, self.cfgs.max_dq_ans).reshape(
            batch_size, self.cfgs.max_dq_ans, -1)

        fuse_feat = torch.cat(
            (flattened_que_feat, dia_feat, opt_feat, evi_feat, fuse_evi_opt,
             fuse_que_opt, fuse_que_evi, fuse_all),
            dim=-1)

        proj_feat = self.classifer(fuse_feat).squeeze(-1)
        return proj_feat
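
The information-gain step in this example is built on the binary entropy of a sigmoid score, H(p) = -p*log2(p) - (1-p)*log2(1-p); the question entropy and the span-conditioned entropy above are both instances of it. A minimal standalone version of that quantity:

import torch

def binary_entropy(p):
    # H(p) for a Bernoulli score p in (0, 1); largest near p = 0.5, near 0 at the extremes.
    return -p * torch.log2(p) - (1 - p) * torch.log2(1 - p)

p = torch.sigmoid(torch.randn(4, 1))
print(binary_entropy(p))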