def forward(self, data, config, gpu_list, acc_result, mode):
    """Score each answer option with a capsule network over encoded text.

    Args:
        data: batch dict; reads "context" and "question" tensors shaped
            (batch, option, seq) -- presumably token ids (TODO confirm) --
            plus "label" (train/valid) or "id" (test).
        config: passed through to ``self.accuracy_function``.
        gpu_list: unused here; kept for the framework's forward signature.
        acc_result: running accuracy accumulator, updated when training.
        mode: "test" switches from loss computation to answer generation.

    Returns:
        {"loss", "acc_result"} when mode != "test",
        otherwise {"output": multi_generate_ans(...)}.
    """
    context = data["context"]
    question = data["question"]
    batch = question.size()[0]
    option = question.size()[1]

    # Flatten (batch, option, seq) -> (batch*option, seq) so every option
    # is encoded and scored independently.
    context = context.view(batch * option, -1)
    question = question.view(batch * option, -1)

    context = self.embedding(context)
    question = self.embedding(question)

    # Encoders return a pair; only the second element is used downstream.
    # (The original bound the first element to c/q and immediately
    # overwrote them, so it is discarded explicitly here.)
    _, context = self.context_encoder(context)
    _, question = self.question_encoder(question)

    # (N, seq, hidden) -> (N, hidden, seq): Conv1d expects channels first.
    c = context.transpose(1, 2)
    q = question.transpose(1, 2)
    y = torch.cat([c, q], dim=1)

    # One convolution per configured n-gram size, concatenated along the
    # length dimension. NOTE(review): the cat hard-codes exactly three
    # branches, so len(self.ngram_size) is assumed to be 3 -- confirm.
    nets_doc_l = []
    for i in range(len(self.ngram_size)):
        nets_doc_l.append(self.convs_doc[i](y))
    nets_doc = torch.cat((nets_doc_l[0], nets_doc_l[1], nets_doc_l[2]), 2)

    # Capsule pipeline: primary capsules -> flatten -> compression ->
    # class capsules producing one logit per type.
    poses_doc, activations_doc = self.primary_capsules_doc(nets_doc)
    poses, activations = self.flatten_capsules(poses_doc, activations_doc)
    poses, activations = self.compression(poses, self.W_doc)
    poses, type_logits = self.fc_capsules_doc_child(
        poses, activations, range(self.fc_size))  # 4 types in total.
    y = type_logits.squeeze(2)

    y = self.fc_module(y)

    if mode != "test":
        label = data["label"]
        loss = self.bce(y, label)
        acc_result = self.accuracy_function(y, label, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}
    return {"output": multi_generate_ans(data["id"], y)}
def forward(self, data, config, gpu_list, acc_result, mode):
    """Attention + capsule-network variant of the answer scorer.

    Args:
        data: batch dict; reads "context" (consumed by the context encoder
            as-is) and "question" shaped (batch, option, seq), plus
            "label" (train/valid) or "id" (test).
        config: passed through to ``self.accuracy_function``.
        gpu_list: unused here; kept for the framework's forward signature.
        acc_result: running accuracy accumulator, updated when training.
        mode: "test" switches from loss computation to answer generation.

    Returns:
        {"loss", "acc_result"} when mode != "test",
        otherwise {"output": multi_generate_ans(...)}.
    """
    context = data["context"]
    question = data["question"]
    batch = question.size()[0]

    # Merge all question tokens for the batch into a single sequence.
    question = question.view(batch, -1)
    question = self.embedding(question)

    # Context encoder takes a fixed length argument; question encoder
    # returns a pair of which only the second element is used.
    context = self.context_encoder(context, self.context_len)
    _, question = self.question_encoder(question)

    # Attend the question over the last context layer; the third output
    # (attention weights in the sibling variants) is unused here.
    c, q, _ = self.attention(context[-1], question)
    x = torch.cat([c, q], dim=1)

    # (batch, seq, hidden) -> (batch, hidden, seq) for Conv1d, cast to
    # float on the appropriate device. NOTE(review): the legacy
    # torch.cuda.FloatTensor cast ignores the actual device of x; prefer
    # x.float() once callers are verified.
    if torch.cuda.is_available():
        x = x.transpose(1, 2).type(torch.cuda.FloatTensor)
    else:
        x = x.transpose(1, 2).type(torch.FloatTensor)

    # One convolution per configured n-gram size, concatenated along the
    # length dimension (assumes exactly three n-gram sizes -- confirm).
    nets_doc_l = []
    for i in range(len(self.ngram_size)):
        nets_doc_l.append(self.convs_doc[i](x))
    nets_doc = torch.cat((nets_doc_l[0], nets_doc_l[1], nets_doc_l[2]), 2)

    # Capsule pipeline ending in one logit per answer type.
    poses_doc, activations_doc = self.primary_capsules_doc(nets_doc)
    poses, activations = self.flatten_capsules(poses_doc, activations_doc)
    poses, activations = self.compression(poses, self.W_doc)
    poses, type_logits = self.fc_capsules_doc_child(
        poses, activations, range(4))  # 4 types in total.
    type_logits = type_logits.squeeze(2)

    if mode != "test":
        label = data["label"]
        loss = self.bce(type_logits, label)
        acc_result = self.accuracy_function(type_logits, label, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}
    return {"output": multi_generate_ans(data["id"], type_logits)}
def forward(self, data, config, gpu_list, acc_result, mode):
    """Mean-pooled encoder variant of the answer scorer.

    Args:
        data: batch dict; reads "context" and "question" tensors shaped
            (batch, option, seq), plus "label" (train/valid) or "id" (test).
        config: passed through to ``self.accuracy_function``.
        gpu_list: unused here; kept for the framework's forward signature.
        acc_result: running accuracy accumulator, updated when training.
        mode: "test" switches from loss computation to answer generation.

    Returns:
        {"loss", "acc_result"} when mode != "test",
        otherwise {"output": multi_generate_ans(...)}.
    """
    context = data["context"]
    question = data["question"]
    batch = question.size()[0]
    option = question.size()[1]

    # Flatten (batch, option, seq) -> (batch*option, seq) so every option
    # is encoded independently.
    context = context.view(batch * option, -1)
    question = question.view(batch * option, -1)

    context = self.embedding(context)
    question = self.embedding(question)

    _, context = self.context_encoder(context)
    _, question = self.question_encoder(question)

    # BUG FIX: the original pooled undefined names `c`/`q` (a NameError at
    # runtime), left over from the commented-out attention path. Pool the
    # encoder outputs instead, matching the sibling variant's intent.
    y = torch.cat([torch.mean(context, dim=1), torch.mean(question, dim=1)], dim=1)

    y = self.gelu(y)
    y = self.dropout(y)
    y = self.fc_module(y)

    if mode != "test":
        label = data["label"]
        loss = self.bce(y, label)
        acc_result = self.accuracy_function(y, label, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}
    return {"output": multi_generate_ans(data["id"], y)}