def forward(self, que_ix, opt_ix, dia, dia_node_ix, ins_dia, cp_ix):
    """
    :param que_ix: the indices of questions
    :param opt_ix: the indices of options
    :param dia: the diagram corresponding to the question
    :param dia_node_ix: the indices of diagram nodes
    :param ins_dia: the instructional diagram of the lesson that contains the question
    :param cp_ix: the closest paragraph, extracted by the TF-IDF method
    """
    batch_size = que_ix.shape[0]
    que_mask = make_mask(que_ix.unsqueeze(2))
    que_feat = self.embedding(que_ix)
    que_feat, _ = self.encode_lang(que_feat)
    que_feat = self.flat(que_feat, que_mask)

    opt_ix = opt_ix.reshape(-1, self.cfgs.max_opt_token)
    opt_mask = make_mask(opt_ix.unsqueeze(2))
    opt_feat = self.embedding(opt_ix)
    opt_feat, _ = self.encode_lang(opt_feat)
    opt_feat = self.flat(opt_feat, opt_mask)
    opt_feat = opt_feat.reshape(batch_size, self.cfgs.max_dq_ans, -1)

    dia_feat = self.simclr(dia)
    dia_feat = dia_feat.reshape(batch_size, 1, -1)

    # fusion_feat: [8, 4096]; the BAN backbone itself outputs [b, 15, 1024]
    fusion_feat = self.backbone(dia_feat, que_feat)
    fusion_feat = fusion_feat.repeat(1, self.cfgs.max_dq_ans).reshape(
        batch_size, self.cfgs.max_dq_ans, -1)
    # fusion_feat: (8, 4, 4096), opt_feat: (8, 4, 1024)
    fuse_opt_feat = torch.cat((fusion_feat, opt_feat), dim=-1)
    proj_feat = self.classifer(fuse_opt_feat).squeeze(-1)
    return proj_feat
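# Hedged sketch (assumption, not from this repo): `make_mask` is undefined in these
# snippets. Given calls such as make_mask(que_ix.unsqueeze(2)) and the later
# masked_fill(opt_num, -1e9), an MCAN-style padding mask over the last dimension is
# consistent with every use here; the original may add unsqueeze() calls for
# attention broadcasting. Imports below are assumed throughout these snippets.
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_mask(feature):
    # True wherever the vector along the last dim is all zeros (a padded position)
    return torch.sum(torch.abs(feature), dim=-1) == 0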
def forward(self, que_ix, opt_ix, cp_ix):
    """
    :param que_ix: the indices of questions
    :param opt_ix: the indices of options
    :param cp_ix: the closest paragraph, extracted by the TF-IDF method
    """
    batch_size = que_ix.shape[0]
    que_feat = self.embedding(que_ix)
    que_feat, _ = self.lang(que_feat)

    # mask of padded option slots; used to recover the actual number of options
    opt_num = make_mask(opt_ix)
    opt_feat = self.embedding(opt_ix)
    opt_feat = opt_feat.reshape(-1, self.cfgs.max_opt_token, self.cfgs.word_emb_size)
    _, opt_feat = self.lang(opt_feat)
    opt_feat = opt_feat.transpose(1, 0).reshape(batch_size, self.cfgs.max_ndq_ans, -1)

    cp_ix = cp_ix.reshape(-1, self.cfgs.max_sent_token)
    cp_feat = self.embedding(cp_ix)
    cp_feat, _ = self.lang(cp_feat)
    cp_feat = cp_feat.reshape(
        batch_size, self.cfgs.max_sent * self.cfgs.max_sent_token, -1)

    fusion_feat = self.backbone(que_feat, cp_feat)
    fusion_feat = self.flatten(fusion_feat.sum(1))
    fusion_feat = fusion_feat.repeat(1, self.cfgs.max_ndq_ans).reshape(
        batch_size, self.cfgs.max_ndq_ans, -1)

    proj_feat = self.classifer(fusion_feat, opt_feat)
    # suppress logits of padded options so softmax/argmax only see real ones
    proj_feat = proj_feat.masked_fill(opt_num, -1e9)
    return proj_feat, torch.sum((opt_num == 0), dim=-1)
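# Toy illustration (not from the repo) of why padded options are filled with -1e9:
# NDQ questions have a variable number of options, padded up to max_ndq_ans, and the
# huge negative logit makes softmax assign padded slots ~0 probability, so the loss
# and argmax only ever see real options.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.0]])
opt_num = torch.tensor([[False, False, True, True]])  # last two options are padding
masked = logits.masked_fill(opt_num, -1e9)
probs = torch.softmax(masked, dim=-1)       # padded options get ~0 probability
n_valid = torch.sum(opt_num == 0, dim=-1)   # tensor([2]): actual number of options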
def forward(self, que_ix, opt_ix, cp_ix):
    """
    :param que_ix: the indices of questions
    :param opt_ix: the indices of options
    :param cp_ix: the closest paragraph, extracted by the TF-IDF method
    """
    batch_size = que_ix.shape[0]
    lang_feat_mask = self.make_mask(que_ix.unsqueeze(2))
    que_feat = self.embedding(que_ix)
    que_feat, _ = self.lstm(que_feat)

    # mask of padded option slots; used to recover the actual number of options
    opt_num = make_mask(opt_ix)
    opt_ix = opt_ix.reshape(-1, self.cfgs.max_opt_token)
    opt_feat_mask = self.make_mask(opt_ix.unsqueeze(2))
    opt_feat = self.embedding(opt_ix)
    opt_feat, _ = self.lstm(opt_feat)
    opt_feat = self.attflat_lang(opt_feat, opt_feat_mask)  # [batch * max_ndq_ans, 1024]
    opt_feat = opt_feat.reshape(batch_size, self.cfgs.max_ndq_ans, -1)

    cp_ix = cp_ix.reshape(batch_size, -1)
    cp_mask = make_mask(cp_ix.unsqueeze(2))
    cp_feat = self.embedding(cp_ix)
    cp_feat, _ = self.lstm(cp_feat)

    # backbone framework: co-attention over question and closest-paragraph features
    lang_feat, cp_feat = self.backbone(que_feat, cp_feat, lang_feat_mask, cp_mask)
    lang_feat = self.attflat_lang(lang_feat, lang_feat_mask)
    cp_feat = self.attflat_img(cp_feat, cp_mask)

    proj_feat = lang_feat + cp_feat
    proj_feat = self.proj_norm(proj_feat)
    proj_feat = proj_feat.repeat(1, self.cfgs.max_ndq_ans).reshape(
        batch_size, self.cfgs.max_ndq_ans, self.cfgs.flat_out_size)

    fuse_feat = torch.cat((proj_feat, opt_feat), dim=-1)
    proj_feat = self.proj(fuse_feat).squeeze(-1)
    proj_feat = proj_feat.masked_fill(opt_num, -1e9)
    return proj_feat, torch.sum((opt_num == 0), dim=-1)
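# Hedged sketch (assumption): `attflat_lang` / `attflat_img` look like MCAN's AttFlat
# module, which pools a masked sequence [b, L, hidden] into one vector via learned
# attention glimpses. Config names (hidden_size, flat_mlp_size, flat_glimpses,
# flat_out_size) follow the public MCAN code and are assumptions here, as is the
# [b, 1, 1, L] mask shape it squeezes.
class AttFlat(nn.Module):
    def __init__(self, cfgs):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(cfgs.hidden_size, cfgs.flat_mlp_size),
            nn.ReLU(inplace=True),
            nn.Linear(cfgs.flat_mlp_size, cfgs.flat_glimpses),
        )
        self.linear_merge = nn.Linear(
            cfgs.hidden_size * cfgs.flat_glimpses, cfgs.flat_out_size)

    def forward(self, x, x_mask):
        # att: [b, L, glimpses]; padded positions are suppressed before softmax
        att = self.mlp(x)
        att = att.masked_fill(x_mask.squeeze(1).squeeze(1).unsqueeze(2), -1e9)
        att = F.softmax(att, dim=1)
        # one attention-weighted sum of x per glimpse, concatenated and merged
        x_atted = torch.cat(
            [torch.sum(att[:, :, i:i + 1] * x, dim=1) for i in range(att.size(2))],
            dim=1)
        return self.linear_merge(x_atted)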
def forward(self, que_ix, opt_ix, dia, ins_dia, cp_ix):
    """
    :param que_ix: the indices of questions
    :param opt_ix: the indices of options
    :param dia: the diagram corresponding to the question
    :param ins_dia: the instructional diagram of the lesson that contains the question
    :param cp_ix: the closest paragraph, extracted by the TF-IDF method
    """
    batch_size = que_ix.shape[0]
    que_feat = self.embedding(que_ix)
    que_feat, que_hidden = self.rnn_qo(que_feat)

    opt_ix = opt_ix.reshape(-1, self.cfgs.max_opt_token)
    opt_mask = make_mask(opt_ix.unsqueeze(2))
    opt_feat = self.embedding(opt_ix)
    opt_feat, _ = self.rnn_qo(opt_feat)
    opt_feat = self.flatten_opt(opt_feat, opt_mask)
    opt_feat = opt_feat.reshape(batch_size, self.cfgs.max_dq_ans, -1)

    cp_feat = self.embedding(cp_ix)
    cp_feat, _ = self.rnn_cp(
        cp_feat.reshape(batch_size, -1, self.cfgs.word_emb_size))

    dia_feat = self.simclr(dia)
    dia_feat = dia_feat.reshape(batch_size, -1)

    # question, diagram, instructional-diagram and paragraph features are fused by
    # forward_with_divide_and_rule
    fusion_feat = self.forward_with_divide_and_rule(
        que_feat, dia_feat, ins_dia, cp_feat, que_hidden)
    fusion_feat = self.flatten(fusion_feat.sum(1))
    fusion_feat = fusion_feat.repeat(1, self.cfgs.max_dq_ans).reshape(
        batch_size, self.cfgs.max_dq_ans, -1)

    proj_feat = self.classifer1(fusion_feat, opt_feat)
    return proj_feat
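# Toy check (not from the repo) of the repeat(1, max_dq_ans).reshape(batch,
# max_dq_ans, -1) idiom used throughout: it tiles one fused vector per answer slot,
# and is equivalent to the arguably clearer unsqueeze/expand formulation.
b, d, n_opt = 2, 3, 4
fusion = torch.arange(b * d, dtype=torch.float).reshape(b, d)
tiled = fusion.repeat(1, n_opt).reshape(b, n_opt, d)
expanded = fusion.unsqueeze(1).expand(b, n_opt, d)
assert torch.equal(tiled, expanded)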
def forward(self, que_ix, opt_ix, cp_ix):
    """
    :param que_ix: the indices of questions
    :param opt_ix: the indices of options
    :param cp_ix: the closest paragraph, extracted by the TF-IDF method
    """
    batch_size = que_ix.shape[0]
    que_feat = self.embedding(que_ix)
    que_feat, _ = self.encode_lang(que_feat)
    que_feat = que_feat.reshape(batch_size, -1, self.hid_dim)

    # mask of padded option slots; used to recover the actual number of options
    opt_num = make_mask(opt_ix)
    opt_ix = opt_ix.reshape(-1, self.cfgs.max_opt_token)
    opt_feat = self.embedding(opt_ix)
    _, opt_feat = self.encode_lang(opt_feat)  # use the final hidden state per option
    opt_feat = opt_feat.reshape(batch_size, self.cfgs.max_ndq_ans, -1)

    cp_ix = cp_ix.reshape(-1, self.cfgs.max_sent_token)
    cp_feat = self.embedding(cp_ix)
    cp_feat, _ = self.encode_lang2(cp_feat)
    cp_feat = cp_feat.reshape(
        batch_size, self.cfgs.max_sent * self.cfgs.max_sent_token, -1)

    img_num, obj_num, feat_size = 1, cp_feat.size(1), cp_feat.size(2)
    assert img_num == 1 and feat_size == 2048
    cp_feat = cp_feat.reshape(batch_size, -1, feat_size)

    output_lang, output_img, output_cross = self.bert_encoder(
        sents=que_feat, feats=cp_feat)
    output_cross = output_cross.view(-1, self.hid_dim)

    b, seq_len, v_len, hid = (output_lang.size(0), output_lang.size(1),
                              output_img.size(1), self.hid_dim)

    # relation experiment: build all ordered pairs of language tokens by
    # broadcasting, concatenate each pair and project it back to hid_dim
    relate_lang_stack_1 = output_lang.view(b, 1, seq_len, hid).repeat(1, seq_len, 1, 1)
    relate_lang_stack_2 = output_lang.view(b, seq_len, 1, hid).repeat(1, 1, seq_len, 1)
    relate_lang_stack = torch.cat(
        (relate_lang_stack_1, relate_lang_stack_2), 3)      # [b, L, L, 2*hid]
    relate_lang_stack = self.lang_2_to_1(relate_lang_stack.view(-1, hid * 2))
    relate_lang_stack = relate_lang_stack.view(b, seq_len, seq_len, hid)

    # image pairs are combined by element-wise addition instead of concatenation
    relate_img_stack_1 = output_img.view(b, 1, v_len, hid)
    relate_img_stack_2 = output_img.view(b, v_len, 1, hid)
    relate_img_stack = relate_img_stack_1 + relate_img_stack_2  # broadcasts to [b, V, V, hid]
    # flatten both pair grids: [b, N, N, hid] -> [b, N*N, hid]
    relate_lang_stack = relate_lang_stack.view(b, seq_len * seq_len, hid)
    relate_img_stack = relate_img_stack.reshape(b, v_len * v_len, hid)

    # keep each unordered pair once: tril_indices gives (i, j) with i > j, which is
    # mapped to a flat position of the pair grid; scaling the column index by N picks
    # the transposed (upper-triangular) entries, which still covers every unordered
    # pair exactly once
    relate_lang_ind = torch.tril_indices(seq_len, seq_len, -1,
                                         device=output_lang.device)
    relate_lang_ind[1] = relate_lang_ind[1] * seq_len
    relate_lang_ind = relate_lang_ind.sum(0)
    relate_lang_stack = relate_lang_stack.index_select(
        1, relate_lang_ind)                                 # [b, L*(L-1)/2, hid]

    relate_img_ind = torch.tril_indices(v_len, v_len, -1,
                                        device=output_img.device)
    relate_img_ind[1] = relate_img_ind[1] * v_len
    relate_img_ind = relate_img_ind.sum(0)
    relate_img_stack = relate_img_stack.index_select(
        1, relate_img_ind)                                  # [b, V*(V-1)/2, hid]

    # score every candidate pair and keep the top-k relations per sample
    lang_candidate_relat_score = self.lang_relation(
        relate_lang_stack.view(-1, self.hid_dim))
    img_candidate_relat_score = self.img_relation(
        relate_img_stack.view(-1, self.hid_dim))
    lang_candidate_relat_score = lang_candidate_relat_score.view(
        b, relate_lang_stack.size(1))
    img_candidate_relat_score = img_candidate_relat_score.view(
        b, relate_img_stack.size(1))

    _, topk_lang_index = torch.topk(lang_candidate_relat_score,
                                    self.top_k_value, sorted=False)
    _, topk_img_index = torch.topk(img_candidate_relat_score,
                                   self.top_k_value, sorted=False)

    lang_relat = torch.cat(
        [relate_lang_stack[i].index_select(0, topk_lang_index[i])
         for i in range(b)], 0)
    img_relat = torch.cat(
        [relate_img_stack[i].index_select(0, topk_img_index[i])
         for i in range(b)], 0)
    lang_relat = lang_relat.view(b, -1, self.hid_dim)       # [b, k, hid]
    img_relat = img_relat.view(b, -1, self.hid_dim)         # [b, k, hid]

    # cosine-similarity map between the selected language and image relations
    relate_cross = torch.einsum('bld,brd->blr',
                                F.normalize(lang_relat, p=2, dim=-1),
                                F.normalize(img_relat, p=2, dim=-1))
    relate_cross = relate_cross.view(-1, 1, relate_cross.size(1),
                                     relate_cross.size(2))
    relate_conv_1 = self.rel_pool1(F.relu(self.rel_conv1(relate_cross)))
    relate_fc1 = F.relu(self.rel_fc1(relate_conv_1.view(-1, 32 * 4 * 4)))
    relate_fc1 = relate_fc1.view(-1, self.hid_dim)
    logit4 = self.logit_fc4(relate_fc1)

    # cross-modality experiment: split the cross feature in half and correlate the
    # two halves with a cosine-similarity map
    output_cross = output_cross.view(-1, output_cross.size(1))
    cross1, cross2 = torch.split(output_cross, output_cross.size(1) // 2, dim=1)
    cross1 = cross1.view(output_cross.size(0), -1, 64)
    cross2 = cross2.view(output_cross.size(0), -1, 64)
    cross_1_2 = torch.einsum('bld,brd->blr',
                             F.normalize(cross1, p=2, dim=-1),
                             F.normalize(cross2, p=2, dim=-1))
    cross_1_2 = cross_1_2.view(-1, 1, cross_1_2.size(1), cross_1_2.size(2))
    cross_conv_1 = self.cross_pool1(F.relu(self.cross_conv1(cross_1_2)))

    cross_fc1 = F.relu(self.cross_fc1(cross_conv_1.view(-1, 32 * 2 * 2)))
    cross_fc1 = cross_fc1.view(-1, self.hid_dim)
    logit1 = self.logit_fc1(cross_fc1)

    # concatenate the two branch logits, tile them across option slots, and score
    # each option against its own features
    cross_logit = torch.cat((logit1, logit4), 1)
    cross_logit = cross_logit.repeat(1, self.cfgs.max_ndq_ans).reshape(
        batch_size, self.cfgs.max_ndq_ans, -1)
    cross_logit = torch.cat((cross_logit, opt_feat), dim=-1)
    logit = self.final_classifier(cross_logit).squeeze(-1)
    logit = logit.masked_fill(opt_num, -1e9)
    return logit, torch.sum((opt_num == 0), dim=-1)
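# Small demo (not from the repo) of the tril_indices trick used above:
# torch.tril_indices(N, N, -1) lists the (i, j) pairs with i > j; turning them into
# flat positions of the row-major [N*N] pair grid selects each unordered pair once.
# The code scales the column index by N (j * N + i), i.e. the transposed entries;
# for the image pairs (built by addition) the grid is symmetric, and for language
# pairs this merely swaps the concatenation order within each pair.
N = 4
ind = torch.tril_indices(N, N, -1)           # shape [2, N*(N-1)/2]; rows=i, cols=j
flat_standard = ind[0] * N + ind[1]          # (i, j) -> i * N + j
flat_in_code = ind[1] * N + ind[0]           # (i, j) -> j * N + i, as in the model
pair_grid = torch.arange(N * N).reshape(1, N * N, 1)
picked = pair_grid.index_select(1, flat_in_code)  # one entry per unordered pair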
def forward(self, que_ix, opt_ix, dia, ins_dia, cp_ix):
    """
    :param que_ix: the indices of questions
    :param opt_ix: the indices of options
    :param dia: the diagram corresponding to the question
    :param ins_dia: the instructional diagram of the lesson that contains the question
    :param cp_ix: the closest paragraph, extracted by the TF-IDF method
    """
    batch_size = que_ix.shape[0]
    que_mask = make_mask(que_ix.unsqueeze(2))
    que_feat = self.embedding(que_ix)
    que_feat, _ = self.encode_lang(que_feat)

    opt_ix = opt_ix.reshape(-1, self.cfgs.max_opt_token)
    opt_mask = make_mask(opt_ix.unsqueeze(2))
    opt_feat = self.embedding(opt_ix)
    opt_feat, _ = self.encode_lang(opt_feat)
    opt_feat = self.flat(opt_feat, opt_mask)
    opt_feat = opt_feat.reshape(batch_size, self.cfgs.max_dq_ans, -1)

    cp_ix = cp_ix.reshape(-1, self.cfgs.max_sent_token)
    cp_mask = make_mask(cp_ix.unsqueeze(2))
    cp_feat = self.embedding(cp_ix)
    cp_feat, _ = self.encode_lang(cp_feat)
    cp_feat = self.flat(cp_feat, cp_mask)
    cp_feat = cp_feat.reshape(batch_size, self.cfgs.max_sent, -1)

    span_feat = get_span_feat(cp_feat, self.cfgs.span_width)
    cp_sent_mask = make_mask(cp_mask == False).reshape(batch_size, -1)
    cp_sent_mask = cp_sent_mask.unsqueeze(-1).repeat(
        1, 1, self.cfgs.span_width).reshape(batch_size, -1)

    dia_feat = self.simclr(dia)

    # compute the binary entropy of each question
    flattened_que_feat = self.flat(que_feat, que_mask)
    que_logit = self.sigmoid(self.score_lang(flattened_que_feat))
    que_entropy = -que_logit * torch.log2(que_logit) \
                  - (1 - que_logit) * torch.log2(1 - que_logit)

    # compute the conditional entropy of each question given candidate spans
    span_num = span_feat.shape[1]
    flattened_que_feat_expand = flattened_que_feat.repeat(
        1, span_num).reshape(batch_size, span_num, -1)
    que_with_evi_feat = torch.cat((flattened_que_feat_expand, span_feat), dim=-1)
    que_with_evi_feat = self.pool3(que_with_evi_feat)
    span_feat = self.pool2(span_feat)
    evi_logit = self.sigmoid(self.score_lang(span_feat))
    cond_logit = self.sigmoid(self.score_lang(que_with_evi_feat))
    que_cond_ent = evi_logit * (-cond_logit * torch.log2(cond_logit)
                                - (1 - cond_logit) * torch.log2(1 - cond_logit))

    # the information gain of each question given the candidate evidence spans;
    # the span with the highest gain is selected as evidence
    inf_gain = que_entropy - que_cond_ent.squeeze(-1)
    evi_ix = get_ix(inf_gain, cp_sent_mask)
    evi_feat = torch.cat(
        [span_feat[i][int(ix)] for i, ix in enumerate(evi_ix)],
        dim=0).reshape(batch_size, -1)
    evi_feat = evi_feat.repeat(1, self.cfgs.max_dq_ans).reshape(
        batch_size, self.cfgs.max_dq_ans, -1)

    # tile the flattened question features across answer slots
    flattened_que_feat = flattened_que_feat.repeat(
        1, self.cfgs.max_dq_ans).reshape(batch_size, self.cfgs.max_dq_ans, -1)

    # use BAN to fuse features of questions and diagrams
    fuse_que_dia = self.backbone(que_feat, dia_feat)
    fuse_que_dia = fuse_que_dia.mean(1).repeat(
        1, self.cfgs.max_dq_ans).reshape(batch_size, self.cfgs.max_dq_ans, -1)

    # pairwise fusions: evidence-option, question-option, question-evidence, and
    # question-option combined with the fused question-diagram features
    fuse_evi_opt = evi_feat * opt_feat
    fuse_que_opt = flattened_que_feat * opt_feat
    fuse_que_evi = flattened_que_feat * evi_feat
    fuse_all = flattened_que_feat * opt_feat * fuse_que_dia

    dia_feat = dia_feat.repeat(1, self.cfgs.max_dq_ans).reshape(
        batch_size, self.cfgs.max_dq_ans, -1)
    fuse_feat = torch.cat(
        (flattened_que_feat, dia_feat, opt_feat, evi_feat,
         fuse_evi_opt, fuse_que_opt, fuse_que_evi, fuse_all), dim=-1)
    proj_feat = self.classifer(fuse_feat).squeeze(-1)
    return proj_feat
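# Hedged sketch (assumption, not the repo's helper): `get_span_feat` is undefined in
# this snippet. From its usage (span_num appears to be max_sent * span_width, judging
# by how cp_sent_mask is tiled), one consistent reading is that it enumerates, for
# every sentence i, the spans of width 1..span_width starting at i, and represents
# each span by concatenating its sentence features, zero-padded to a fixed width.
# The real helper may differ.
def get_span_feat(cp_feat, span_width):
    # cp_feat: [batch, max_sent, dim] sentence features of the closest paragraph
    batch_size, max_sent, dim = cp_feat.shape
    padded = F.pad(cp_feat, (0, 0, 0, span_width - 1))  # pad trailing sentences
    spans = []
    for start in range(max_sent):
        for width in range(1, span_width + 1):
            span = padded[:, start:start + width].reshape(batch_size, -1)
            # zero-pad every span representation to a fixed span_width * dim
            span = F.pad(span, (0, (span_width - width) * dim))
            spans.append(span)
    # [batch, max_sent * span_width, span_width * dim]
    return torch.stack(spans, dim=1)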