示例#1
0
    def forward(self, data, config, gpu_list, acc_result, mode):
        """Score each answer option for a multiple-choice question.

        Embeds and encodes context/question per option, attends them against
        each other, max-pools, and ranks. Returns loss + accuracy stats in
        train/valid mode, generated answers in test mode.
        """
        ctx, ques = data["context"], data["question"]
        batch, option = ques.size()[0], ques.size()[1]

        # Flatten (batch, option, seq) -> (batch*option, seq) so every option
        # is embedded and encoded independently.
        ctx = self.embedding(ctx.view(batch * option, -1))
        ques = self.embedding(ques.view(batch * option, -1))

        _, ctx = self.context_encoder(ctx)
        _, ques = self.question_encoder(ques)

        c, q, a = self.attention(ctx, ques)

        # Max-pool both attended views over the sequence dimension.
        pooled = torch.cat([c.max(dim=1)[0], q.max(dim=1)[0]], dim=1)

        scores = self.rank_module(pooled.view(batch * option, -1))
        scores = self.multi_module(scores.view(batch, option))

        if mode == "test":
            return {"output": generate_ans(data["id"], scores)}

        label = data["label"]
        loss = self.criterion(scores, label)
        acc_result = self.accuracy_function(scores, label, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}
示例#2
0
    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode context/question with a length-capped encoder, run n-gram
        convs + capsule layers, and emit per-type logits.

        Returns loss + accuracy stats in train/valid mode, generated answers
        in test mode.
        """
        context = data["context"]
        question = data["question"]
        batch = question[0].size()[0]
        seq_len = question[0].size()[1]
        # context/question are (ids, mask, type_ids)-style tuples; the encoder
        # returns a pair and we keep the first element (sequence output).
        context, _ = self.context_encoder(*context)

        if seq_len > 512:
            # The encoder caps inputs at 512 tokens: encode the question in
            # consecutive 512-token windows and sum the window outputs.
            # NOTE(review): tokens beyond (seq_len // 512) * 512 are dropped —
            # a partial window could not be summed with full-size ones.
            n = seq_len // 512
            a, b, c = question
            temp_question = None
            for i in range(n):
                lo, hi = i * 512, (i + 1) * 512
                _a, _b, _c = a[:, lo:hi], b[:, lo:hi], c[:, lo:hi]
                # Skip windows that contain only padding (all-zero mask).
                if torch.any(_b.bool()):
                    _question, _ = self.context_encoder(_a, _b, _c)
                    # BUGFIX: accumulate on the first *encoded* window, not on
                    # i == 0 — if window 0 is all padding it is skipped above
                    # and the old `if i:` branch crashed adding to None.
                    if temp_question is None:
                        temp_question = _question
                    else:
                        temp_question = temp_question + _question
            question = temp_question
        else:
            question, _ = self.context_encoder(*question)

        context = context.view(batch, -1, self.hidden_size)
        question = question.view(batch, -1, self.hidden_size)

        # Concatenate along the sequence dimension; no cross-attention here.
        y = torch.cat([context, question], dim=1)

        # One conv per configured n-gram size; concatenate every feature map.
        # (Generalized from a hard-coded 3-way cat so any number of n-gram
        # sizes works; identical output for the original 3 sizes.)
        nets_doc_l = [self.convs_doc[i](y) for i in range(len(self.ngram_size))]
        nets_doc = torch.cat(nets_doc_l, dim=2)

        poses_doc, activations_doc = self.primary_capsules_doc(nets_doc)
        poses, activations = self.flatten_capsules(poses_doc, activations_doc)
        poses, activations = self.compression(poses, self.W_doc)
        poses, type_logits = self.fc_capsules_doc_child(
            poses, activations, range(4))  # 4 types in total.
        y = type_logits.squeeze(2)

        y = self.dropout(y)

        if mode != "test":
            label = data["label"]
            loss = self.bce(y, label)
            acc_result = self.accuracy_function(y, label, config, acc_result)
            return {"loss": loss, "acc_result": acc_result}

        return {"output": generate_ans(data["id"], y)}
示例#3
0
    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode context/question, reshape into a 3-channel map for the
        resnet, append the `if_positive` feature, and classify.

        Returns loss + accuracy stats in train/valid mode, generated answers
        in test mode.
        """
        context = data["context"]
        question = data["question"]
        if_positive = data["if_positive"]

        batch = question[0].size()[0]
        seq_len = question[0].size()[1]
        context, _ = self.context_encoder(*context)

        if seq_len > 512:
            # The encoder caps inputs at 512 tokens: encode the question in
            # consecutive fixed-size windows and sum the window outputs.
            window = 512
            n = seq_len // window
            a, b, c = question
            temp_question = None
            for i in range(n):
                lo, hi = i * window, (i + 1) * window
                _a, _b, _c = a[:, lo:hi], b[:, lo:hi], c[:, lo:hi]
                # Skip windows that contain only padding (all-zero mask).
                if torch.any(_b.bool()):
                    _question, _ = self.question_encoder(_a, _b, _c)
                    # BUGFIX: accumulate on the first *encoded* window, not on
                    # i == 0 — if window 0 is all padding it is skipped above
                    # and the old `if i:` branch crashed adding to None.
                    if temp_question is None:
                        temp_question = _question
                    else:
                        temp_question = temp_question + _question
            question = temp_question
        else:
            question, _ = self.question_encoder(*question)

        question = question.view(batch, -1, self.hidden_size)
        context = context.view(batch, -1, self.hidden_size)

        y = torch.cat([context, question], dim=1)
        # Fold the hidden dim into 3 "image channels" for the 2-D resnet.
        y = y.view(batch, -1, 3, self.hidden_size // 3)
        y = y.transpose(1, 2)
        y = self.resnet(y)
        y = self.res_module(y)
        # Append the scalar if_positive feature before the final classifier.
        y = torch.cat([y, if_positive], dim=1)
        y = self.fc_module(y)
        y = self.softmax(y)

        if mode != "test":
            label = data["label"]
            loss = self.bce(y, label)
            acc_result = self.accuracy_function(y, label, config, acc_result)
            return {"loss": loss, "acc_result": acc_result}

        return {"output": generate_ans(data["id"], y)}
示例#4
0
    def forward(self, data, config, gpu_list, acc_result, mode):
        """Embed and encode each (context, question) option pair, classify
        through a resnet with the `if_positive` feature appended.

        Returns loss + accuracy stats in train/valid mode, generated answers
        in test mode.
        """
        ctx = data["context"]
        ques = data["question"]
        if_positive = data["if_positive"]

        batch, option = ques.size()[0], ques.size()[1]

        # One row per (example, option) pair.
        ctx = self.embedding(ctx.view(batch * option, -1))
        ques = self.embedding(ques.view(batch * option, -1))

        # The first element of each encoder's output pair feeds the head.
        c, ctx = self.context_encoder(ctx)
        q, ques = self.question_encoder(ques)

        feats = torch.cat([c, q], dim=1)
        # Fold the hidden dim into 3 "image channels" for the 2-D resnet.
        feats = feats.view(batch, -1, 3, self.hidden_size // 3).transpose(1, 2)
        feats = self.dropout(self.resnet(feats))
        feats = self.res_module(feats)
        # Append the if_positive feature before the final classifier.
        feats = self.fc_module(torch.cat([feats, if_positive], dim=1))
        y = self.softmax(feats)

        if mode == "test":
            return {"output": generate_ans(data["id"], y)}

        label = data["label"]
        loss = self.criterion(y, label)
        acc_result = self.accuracy_function(y, label, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}
示例#5
0
    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode context/question with a length-capped encoder, run an
        SE-resnet over the folded representation, and rank.

        Returns loss + accuracy stats in train/valid mode, generated answers
        in test mode.
        """
        context = data["context"]
        question = data["question"]
        batch = question[0].size()[0]
        seq_len = question[0].size()[1]
        context, _ = self.context_encoder(*context)

        if seq_len > 512:
            # The encoder caps inputs at 512 tokens: encode the question in
            # consecutive 512-token windows and sum the window outputs.
            n = seq_len // 512
            a, b, c = question
            temp_question = None
            for i in range(n):
                lo, hi = i * 512, (i + 1) * 512
                _a, _b, _c = a[:, lo:hi], b[:, lo:hi], c[:, lo:hi]
                # CONSISTENCY FIX: the sibling models skip windows that are
                # pure padding instead of encoding them; do the same here,
                # and accumulate on the first *encoded* window so a skipped
                # window 0 cannot leave temp_question as None mid-sum.
                if torch.any(_b.bool()):
                    _question, _ = self.context_encoder(_a, _b, _c)
                    if temp_question is None:
                        temp_question = _question
                    else:
                        temp_question = temp_question + _question
            question = temp_question
        else:
            question, _ = self.context_encoder(*question)

        context = context.view(batch, -1, self.hidden_size)
        question = question.view(batch, -1, self.hidden_size)

        y = torch.cat([context, question], dim=1)
        # Fold the hidden dim into 3 "image channels" for the 2-D SE-resnet.
        y = y.view(batch, -1, 3, self.hidden_size // 3)
        y = y.transpose(1, 2)
        y = self.seresnet(y)
        y = self.rank_module(y)

        if mode != "test":
            label = data["label"]
            loss = self.bce(y, label)
            acc_result = self.accuracy_function(y, label, config, acc_result)
            return {"loss": loss, "acc_result": acc_result}

        return {"output": generate_ans(data["id"], y)}
示例#6
0
文件: HAF.py 项目: thunlp/jec-qa
    def forward(self, data, config, gpu_list, acc_result, mode):
        """HAF-style scorer for multiple-choice QA over top-k passages.

        Encodes passages, the question, and each answer option, runs a chain
        of pairwise attention weightings, and produces one score per option.
        Returns loss + accuracy stats when mode != "test", otherwise generated
        answers.

        Args:
            data: batch dict with "passage", "question", "option", "id" and
                (when mode != "test") "label".
            config: experiment config; "data.topk" gives k passages per option.
            gpu_list: unused in this method (part of the shared signature).
            acc_result: running accuracy accumulator, threaded through.
            mode: "train"/"valid"/"test".
        """
        passage = data["passage"]
        question = data["question"]
        option = data["option"]

        batch = question.size()[0]
        option_num = option.size()[1]
        k = config.getint("data", "topk")

        # Flatten so every (example, option, passage) triple is one row.
        passage = passage.view(batch * option_num * k, -1)
        question = question.view(batch, -1)
        option = option.view(batch * option_num, -1)
        # print(passage.size(), question.size(), option.size())

        passage = self.embedding(passage)
        question = self.embedding(question)
        option = self.embedding(option)
        # print(passage.size(), question.size(), option.size())

        # Keep the second element of each encoder's output pair.
        _, passage = self.passage_encoder(passage)
        _, question = self.question_encoder(question)
        _, option = self.option_encoder(option)
        # print(passage.size(), question.size(), option.size())

        # Broadcast question over (option_num, k) and option over k so all
        # three tensors line up as (batch*option_num*k, seq, hidden).
        passage = passage.view(batch * option_num * k, -1, self.hidden_size)
        question = question.view(batch, 1, 1, -1, self.hidden_size).repeat(1, option_num, k, 1, 1).view(
            batch * option_num * k, -1, self.hidden_size)
        option = option.view(batch, option_num, 1, -1, self.hidden_size).repeat(1, 1, k, 1, 1).view(
            batch * option_num * k, -1, self.hidden_size)
        # print(passage.size(), question.size(), option.size())

        # Question-to-passage attention weights, applied back onto passage.
        vp = self.q2p(question, passage).view(batch * option_num * k, -1, 1)
        # print("vp", vp.size())
        vp = vp * passage
        # print("vp", vp.size())
        # Question-to-option attention weights, applied back onto option.
        vo = self.q2o(question, option).view(batch * option_num * k, -1, 1)
        # print("vo", vo.size())
        vo = vo * option
        # print("vo", vo.size())
        _, vpp = self.s(vp)
        # print("vpp", vpp.size())
        # NOTE(review): reuses the q2o module to weight vpp by vo — confirm
        # this is intended rather than a dedicated o2p attention.
        rp = self.q2o(vo, vpp).view(batch * option_num * k, -1, 1)
        # print("rp", rp.size())
        rp = rp * vpp
        # print("rp", rp.size())
        # NOTE(review): self.oc correlates vo with itself here — verify that
        # self-correlation (rather than cross-option correlation) is intended.
        vop = self.oc(vo, vo).view(batch * option_num * k, -1, 1)
        # print("vop", vop.size())
        vop = vop * vo
        # print("vop", vop.size())
        ro = torch.cat([vo, vo - vop], dim=2)
        # print("ro", ro.size())

        # Bilinear-style match between the weighted passage and the option
        # representation, reduced to one score per (option, passage) row.
        s = self.wp(rp)
        ro = torch.transpose(ro, 2, 1)
        s = torch.bmm(s, ro)
        # print(s.size())
        s = s.view(batch * option_num, -1)
        s = self.score(s)
        y = s.view(batch, option_num)
        # print(y.size())

        # Optional multi-label head on top of the per-option scores.
        if self.multi:
            y = self.multi_module(y)

        if mode != "test":
            label = data["label"]
            loss = self.criterion(y, label)
            acc_result = self.accuracy_function(y, label, config, acc_result)
            return {"loss": loss, "acc_result": acc_result}

        return {"output": generate_ans(data["id"], y)}
示例#7
0
    def forward(self, data, config, gpu_list, acc_result, mode):
        """Attend each of the 4 answer options' BERT context against the
        question encoding, max-pool, and classify.

        Returns loss + accuracy stats in train/valid mode, generated answers
        in test mode.

        Cleanup: removed a dead `question = q` assignment inside the loop
        (the attention always consumes `bert_question`; the reassigned value
        was never read) and the large blocks of commented-out experiments
        that obscured the live data path.
        """
        context = data["context"]
        question = data["question"]

        batch = question[0].size()[0]

        # Encoder returns a triple; the third element holds per-layer hidden
        # states, of which we keep the last layer.
        _, _, bert_question = self.context_encoder(*question)
        _, _, bert_context = self.context_encoder(*context)

        # Split the context encoding into the 4 answer options.
        bert_context = bert_context[-1].reshape(batch, 4, self.context_len, -1)
        bert_question = bert_question[-1]

        # Attend each option's context against the shared question encoding.
        contextpool = []
        for i in range(4):
            option = bert_context[:, i, :, :].squeeze(1)
            c, q, a = self.attention(option, bert_question)
            contextpool.append(c)

        # Max-pool each attended option over its sequence dim, concatenate.
        y = torch.cat([torch.max(c, dim=1)[0] for c in contextpool], dim=1)
        y = y.flatten(start_dim=1)
        y = self.fc_module(y)

        if mode != "test":
            label = data["label"]
            loss = self.bce(y, label)
            acc_result = self.accuracy_function(y, label, config, acc_result)
            return {"loss": loss, "acc_result": acc_result}

        return {"output": generate_ans(data["id"], y)}