Example #1
File: train.py  Project: zyksir/kbqa
def train_epoch(seq, seq_len, sub, rel, rel_len, crel, crel_label, crel_len,
                encoder, decoder):
    encoder_hidden = encoder(seq, seq_len, rel, sub, rel_len)
    if args.atten_mode != "BiLSTM":
        decoder_hidden = Variable(
            encoder_hidden)  # (1, batch_size, out_put_dim)

        # max_rel_len = 2
        batch_size = seq.size()[0]
        decoder_input = Variable(
            torch.LongTensor([rel_vocab.lookup("<_start>")] *
                             batch_size)).to(device)
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)

        # decoder_output: B * vocab_size
        crel_score = torch.gather(decoder_output, 1, crel)
    else:
        crel_score = torch.gather(encoder_hidden, 1, crel)
    # crel_score = Variable(torch.zeros(batch_size, crel.size(1))).to(device)
    # for i in range(batch_size):
    #     crel_score[i] = decoder_output[i].index_select(0, crel[i])

    mask = get_mask(crel_len).to(device)
    scores = torch.masked_select(crel_score, mask)
    labels = torch.masked_select(crel_label, mask)
    scores = sigmoid(scores)

    # Turn each sigmoid score into a two-column distribution
    # (column 0: negative, column 1: positive) before applying the criterion.
    loss_scores = Variable(torch.zeros(scores.size(0), 2)).to(device)
    loss_scores[:, 0] = 1 - scores
    loss_scores[:, 1] = scores
    loss = criterion(loss_scores, labels)

    return loss
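
get_mask is used throughout these snippets but its definition is not shown. Below is a minimal sketch of what it might look like, assuming it turns a 1-D length tensor into a (batch, max_len) boolean matrix that is True at valid (non-padded) positions; the get_pad flag seen in models.py is guessed to flip that convention, so the real helper may differ:

import torch

def get_mask(lengths, get_pad=False):
    # Hypothetical reimplementation; not the repository's code.
    # lengths: LongTensor of shape (batch,)
    max_len = int(lengths.max())
    positions = torch.arange(max_len).unsqueeze(0)      # (1, max_len)
    valid = positions < lengths.unsqueeze(1)            # (batch, max_len), True where j < lengths[i]
    return ~valid if get_pad else valid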
Example #2
File: predict.py  Project: zyksir/kbqa
def new_evaluate(dataloader, encoder, decoder, device, start, biLSTM=False):
    label_total = []
    predict_total = []
    acc, recall, precision, f1, count = 0, 0, 0, 0, 0
    # for batch_idx, batch in enumerate(dataloader.next_batch()):
    for batch_idx, batch in tqdm(enumerate(dataloader.next_batch())):
        seq, seq_len, sub, rel, rel_len, crel, crel_label, crel_len = batch
        encoder_hidden = encoder(seq, seq_len, rel, sub, rel_len)
        decoder_output, decoder_hidden = decoder(encoder_hidden[:, :, :decoder.input_size].contiguous(),
                                                 encoder_hidden[:, :, decoder.input_size:].contiguous())
        crel_score = torch.gather(decoder_output, 1, crel)

        mask = get_mask(crel_len)
        scores = torch.masked_select(crel_score.cpu(), mask)
        labels = torch.masked_select(crel_label.cpu(), mask)
        # Per-batch metrics: threshold the sigmoid scores at 0.5 and compare
        # them against the gold labels of this batch.
        label_total = labels.tolist()
        predict = sigmoid(scores) > 0.5
        predict_total = predict.tolist()

        label_total = torch.LongTensor(label_total)
        predict_total = torch.LongTensor(predict_total)
        acc += accuracy_score(label_total, predict_total)
        recall += recall_score(label_total, predict_total)
        precision += precision_score(label_total, predict_total)
        f1 += f1_score(label_total, predict_total)
        count += 1

    return acc/count, recall/count, precision/count, f1/count
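
new_evaluate averages per-batch metrics, so a small final batch carries the same weight as a full one. A possible alternative, sketched here under the assumption that labels and sigmoid scores are collected over the whole run (this helper is not part of the repository):

import torch
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

def corpus_level_metrics(batched_labels, batched_scores, threshold=0.5):
    # batched_labels / batched_scores: lists of 1-D tensors, one pair per batch.
    labels = torch.cat(batched_labels).long()
    predictions = (torch.cat(batched_scores) > threshold).long()
    return (accuracy_score(labels, predictions),
            recall_score(labels, predictions),
            precision_score(labels, predictions),
            f1_score(labels, predictions))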
Example #3
File: models.py  Project: zyksir/kbqa
    def forward(self, seq, seq_len, rel, sub, rel_len):
        seq = seq.transpose(0, 1)
        rel = rel.transpose(0, 1)

        seq = self.word_embedding(seq)
        rel = self.rel_embedding(rel)
        # The subject argument is discarded before encoding.
        sub = None

        # rel_encode_output: (neg_size, batch, rel_dim)
        # rel_encode_hidden: (batch, sdim)
        # seq_encode_output: (length, batch, seq_in_size)
        # seq_encode_hidden: (batch, hidden)
        if self.mode == "self":
            rel_encode_output, rel_encode_hidden = self.RelationEncoder(
                rel.transpose(0, 1),
                slf_attn_mask=get_mask(rel_len, get_pad=False),
                non_pad_mask=get_mask(rel_len, get_pad=True))
        else:
            rel_encode_output, rel_encode_hidden = self.RelationEncoder(
                rel, sub, input_length=rel_len)
        seq_encode_output, seq_encode_hidden = self.QuestionEncoder(
            seq, seq_len, hidden=None)

        if self.mode == "both" or self.mode == "rel":
            rel_attention, rel_weight = self.rel_attention(
                seq_encode_hidden, rel_encode_output.transpose(0, 1), rel_len)
        else:
            rel_attention = rel_encode_hidden

        if self.mode == "both" or self.mode == "seq":
            seq_attention, seq_weight = self.seq_attention(
                rel_encode_hidden, seq_encode_output.transpose(0, 1), seq_len)
        else:
            seq_attention = seq_encode_hidden
        # print(seq_attention.size())
        # print(rel_attention.size())

        # encode_output: S*B*D
        # encode_hidden: NUM_LAYER*B*D

        hidden = torch.cat((seq_attention, rel_attention), 1)
        # hidden = self.matrix(hidden)
        hidden = hidden.unsqueeze(0)
        return hidden
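
The seq_attention and rel_attention modules are defined elsewhere in models.py; judging from the calls above they take a (batch, dim) query, a (batch, len, dim) memory and the memory lengths, and return a context vector plus the attention weights. A minimal dot-product sketch with length masking, assuming matching dimensions (the real modules may use a different scoring function):

import torch
import torch.nn as nn
import torch.nn.functional as F

class DotAttention(nn.Module):
    # Hypothetical stand-in, not the repository's implementation.
    def forward(self, query, memory, lengths):
        # query: (batch, dim); memory: (batch, len, dim); lengths: (batch,)
        scores = torch.bmm(memory, query.unsqueeze(2)).squeeze(2)           # (batch, len)
        positions = torch.arange(memory.size(1), device=memory.device).unsqueeze(0)
        scores = scores.masked_fill(positions >= lengths.to(memory.device).unsqueeze(1), float("-inf"))
        weights = F.softmax(scores, dim=1)                                  # (batch, len)
        context = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)        # (batch, dim)
        return context, weights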
Example #4
File: predict.py  Project: zyksir/kbqa
def evaluate(dataloader, encoder, decoder, device, start, biLSTM=False):
    label_total = []
    predict_total = []
    acc, recall, precision, f1, count = 0, 0, 0, 0, 0
    # for batch_idx, batch in enumerate(dataloader.next_batch()):
    for batch_idx, batch in tqdm(enumerate(dataloader.next_batch())):
        seq, seq_len, sub, rel, rel_len, crel, crel_label, crel_len = batch
        encoder_hidden = encoder(seq, seq_len, rel, sub, rel_len)
        if biLSTM:
            decoder_output = encoder_hidden
        else:
            decoder_hidden = Variable(encoder_hidden)  # (1, batch_size, out_put_dim)

            # max_rel_len = 2
            batch_size = seq.size()[0]
            decoder_input = Variable(torch.LongTensor([start] * batch_size)).to(device)
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)

            # decoder_output: B * vocab_size
            # crel_score = torch.zeros(batch_size, crel.size(1))

        crel_score = torch.gather(decoder_output, 1, crel)
        mask = get_mask(crel_len)
        scores = torch.masked_select(crel_score.cpu(), mask)
        labels = torch.masked_select(crel_label.cpu(), mask)
        # Per-batch metrics: threshold the sigmoid scores at 0.5 and compare
        # them against the gold labels of this batch.
        label_total = labels.tolist()
        predict = sigmoid(scores) > 0.5
        predict_total = predict.tolist()

        label_total = torch.LongTensor(label_total)
        predict_total = torch.LongTensor(predict_total)
        acc += accuracy_score(label_total, predict_total)
        recall += recall_score(label_total, predict_total)
        precision += precision_score(label_total, predict_total)
        f1 += f1_score(label_total, predict_total)
        count += 1

    return acc/count, recall/count, precision/count, f1/count
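
Both evaluate variants rely on torch.gather to pull each candidate relation's score out of the vocabulary-sized decoder output. A toy, self-contained illustration of that indexing (the numbers are made up):

import torch

decoder_output = torch.tensor([[0.1, 0.7, 0.2, 0.9],   # batch of 2, vocabulary of 4
                               [0.3, 0.4, 0.8, 0.1]])
crel = torch.tensor([[3, 1],                            # candidate relation indices per example
                     [2, 0]])
crel_score = torch.gather(decoder_output, 1, crel)
# crel_score: tensor([[0.9000, 0.7000],
#                     [0.8000, 0.3000]])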
Example #5
File: train.py  Project: zyksir/kbqa
    correct_rels = question2rel[row["formatted_question"]]
    if batch_size == 0:
        continue
    seq = torch.LongTensor(
        [seq_rel_vocab.convert_to_index(row["formatted_question"].split())] *
        batch_size).to(device)
    seq_len = torch.LongTensor([len(row["formatted_question"].split())] *
                               batch_size)
    sub = None
    rel_len = torch.LongTensor([len(x) for x in row["subject_relation"]])
    cand_rel = torch.LongTensor(batch_size, max(rel_len)).fill_(
        rel_vocab.lookup(rel_vocab.pad_token)).to(device)
    for i, rels in enumerate(row["subject_relation"]):
        cand_rel[i][:rel_len[i]] = torch.LongTensor(
            rel_vocab.convert_to_index(row["subject_relation"][i]))
    mask = get_mask(rel_len)
    encoder_hidden = encoder(seq, seq_len, cand_rel, sub, rel_len)
    # decoder_output, decoder_hidden = decoder(encoder_hidden[:, :, :decoder.input_size].contiguous(),
    #                                          encoder_hidden[:, :, decoder.input_size:].contiguous())
    decoder_hidden = Variable(encoder_hidden)  # (1, batch_size, out_put_dim)
    # max_rel_len = 2
    decoder_input = Variable(torch.LongTensor([start] * batch_size)).to(device)
    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
    # Equivalent to the torch.gather calls above: pick each candidate
    # relation's score out of the decoder output for this question.
    crel_score = torch.zeros(batch_size, cand_rel.size(1))
    for i in range(batch_size):
        crel_score[i] = decoder_output[i].cpu().index_select(
            0, cand_rel[i].cpu())
    scores = sigmoid(torch.masked_select(crel_score, mask))

    candidate_relation = rel_vocab.convert_to_word(
        torch.masked_select(cand_rel.cpu(), mask).tolist())
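
The snippet stops before a relation is actually chosen. A plausible continuation, written here as an assumption rather than code from the repository, would rank the masked candidates by score and check the best one against correct_rels:

ranked = sorted(zip(candidate_relation, scores.tolist()),
                key=lambda pair: pair[1], reverse=True)
predicted_relation = ranked[0][0]
hit = predicted_relation in correct_rels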