Example No. 1
    def forward(self, batch):
        entities, answer_dist = batch

        # numpy to tensor
        entities = use_cuda(
            Variable(torch.from_numpy(entities).type('torch.LongTensor'),
                     requires_grad=False))
        answer_dist = use_cuda(
            Variable(torch.from_numpy(answer_dist).type('torch.LongTensor'),
                     requires_grad=False))

        # entity embedding
        entity_emb = self.entity_embedding(entities)
        if self.has_entity_kge:
            entity_emb = torch.cat(
                (entity_emb, self.entity_kge(entities)),
                dim=2)  # batch_size, max_local_entity, word_dim + kge_dim
        if self.word_dim != self.entity_dim:
            entity_emb = self.entity_linear1(self.linear_drop(
                entity_emb))  # batch_size, max_local_entity, entity_dim
        entity_emb = self.relu(
            self.entity_linear2(self.linear_drop(entity_emb)))
        score = self.relu(self.entity_linear3(self.linear_drop(entity_emb)))

        score = score.squeeze()
        loss = self.cross_loss(score, answer_dist)
        pred_dist = self.sigmoid(score)
        pred = torch.max(pred_dist, dim=1)[1]

        return loss, pred, pred_dist
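
Every snippet on this page routes tensors and modules through a use_cuda helper imported from the project's util module. The helper itself is never shown; a minimal sketch of what such a helper typically looks like (an assumption, not the project's actual code):

import torch

def use_cuda(obj):
    # Assumed behavior: move a tensor/Variable or nn.Module to the GPU when one
    # is available, otherwise return it unchanged so the code also runs on CPU.
    if torch.cuda.is_available():
        return obj.cuda()
    return obj
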
Example No. 2
def get_model(cfg, num_kb_relation, num_entities, num_vocab):
    word_emb_file = (None if cfg['word_emb_file'] is None
                     else cfg['data_folder'] + cfg['word_emb_file'])
    entity_emb_file = (None if cfg['entity_emb_file'] is None
                       else cfg['data_folder'] + cfg['entity_emb_file'])
    entity_kge_file = (None if cfg['entity_kge_file'] is None
                       else cfg['data_folder'] + cfg['entity_kge_file'])
    relation_emb_file = (None if cfg['relation_emb_file'] is None
                         else cfg['data_folder'] + cfg['relation_emb_file'])
    relation_kge_file = (None if cfg['relation_kge_file'] is None
                         else cfg['data_folder'] + cfg['relation_kge_file'])

    my_model = use_cuda(
        GraftNet(word_emb_file, entity_emb_file, entity_kge_file,
                 relation_emb_file, relation_kge_file, cfg['num_layer'],
                 num_kb_relation, num_entities, num_vocab, cfg['entity_dim'],
                 cfg['word_dim'], cfg['kge_dim'], cfg['pagerank_lambda'],
                 cfg['fact_scale'], cfg['lstm_dropout'], cfg['linear_dropout'],
                 cfg['use_kb'], cfg['use_doc']))

    if cfg['load_model_file'] is not None:
        print('loading model from', cfg['load_model_file'])
        pretrained_model_states = torch.load(cfg['load_model_file'])
        if word_emb_file is not None:
            del pretrained_model_states['word_embedding.weight']
        if entity_emb_file is not None:
            del pretrained_model_states['entity_embedding.weight']
        my_model.load_state_dict(pretrained_model_states, strict=False)

    return my_model
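
A hypothetical call site for get_model, assuming a config dict with the same keys the function reads (all paths and sizes below are invented for illustration):

cfg = {
    'data_folder': 'datasets/webqsp/',   # hypothetical values
    'word_emb_file': 'word_emb.npy',
    'entity_emb_file': None,
    'entity_kge_file': None,
    'relation_emb_file': None,
    'relation_kge_file': None,
    'num_layer': 3,
    'entity_dim': 100,
    'word_dim': 300,
    'kge_dim': 100,
    'pagerank_lambda': 0.8,
    'fact_scale': 3,
    'lstm_dropout': 0.3,
    'linear_dropout': 0.2,
    'use_kb': True,
    'use_doc': False,
    'load_model_file': None,
}
my_model = get_model(cfg, num_kb_relation=500, num_entities=100000, num_vocab=40000)
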
Example No. 3
def gen(model, input, size=100, temp=0.5):
    """
	Generate data from the model
	"""
    with torch.no_grad():
        data = []
        # Init input
        if util.use_cuda():
            input = input.cuda()
            model = model.cuda()
        input = Variable(input)
        #  Print the input
        print('IN: [', end='', flush=True)
        for i in input:
            print(i.item(), end=', ', flush=True)
        print(']', end='', flush=True)
        print()
        # Get generated data

        for i in range(size):
            output = model(input[None, :])
            c = sample(output[0, -1, :], temp)
            #print(" SAMPLE: ", c)
            #print(str(chr(max(32, c))), end='', flush=True)
            data.append(c.item())
            # Make next prediction informed by this prediction
            input = torch.cat([input[1:], c[None]], dim=0)
        #print(data[:30])
        #print()
        #print(pitches)
        return data
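
gen relies on a sample helper that is not included in the snippet. A plausible temperature-sampling implementation that matches how it is called here (one score vector in, one token index out) is sketched below; the project's actual sample may differ:

import torch
import torch.nn.functional as F

def sample(lnprobs, temperature=1.0):
    # lnprobs: (log-)scores over the vocabulary for a single position.
    # temperature < 1 sharpens the distribution; temperature 0 falls back to argmax.
    if temperature == 0.0:
        return lnprobs.argmax()
    probs = F.softmax(lnprobs / temperature, dim=0)
    return torch.multinomial(probs, 1).squeeze(0)
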
Example No. 4
def get_model(entity2id):
    my_model = use_cuda(
        Classifier(entity_emb_file, entity_kge_file, len(entity2id), word_dim,
                   kge_dim, entity_dim, linear_dropout))
    if load_model is not None:
        print("loading model from", load_model)
        pretrained_model_states = torch.load(load_model)
        if entity_emb_file is not None:
            del pretrained_model_states['entity_embedding.weight']
        del pretrained_model_states['cross_loss.weight']
        my_model.load_state_dict(pretrained_model_states, strict=False)
    return my_model
Example No. 5
def get_model(cfg, num_kb_relation, num_entities, num_vocab):
    # word_emb_file = None if cfg['word_emb_file'] is None else cfg['data_folder'] + cfg['word_emb_file']
    # entity_emb_file = None if cfg['entity_emb_file'] is None else cfg['data_folder'] + cfg['entity_emb_file']
    # entity_kge_file = None if cfg['entity_kge_file'] is None else cfg['data_folder'] + cfg['entity_kge_file']
    # relation_emb_file = None if cfg['relation_emb_file'] is None else cfg['data_folder'] + cfg['relation_emb_file']
    # relation_kge_file = None if cfg['relation_kge_file'] is None else cfg['data_folder'] + cfg['relation_kge_file']

    #my_model = cuda(GraftNet(word_emb_file, entity_emb_file, entity_kge_file, relation_emb_file, relation_kge_file, cfg['num_layer'], num_kb_relation, num_entities, num_vocab, cfg['entity_dim'], cfg['word_dim'], cfg['kge_dim'], cfg['pagerank_lambda'], cfg['fact_scale'], cfg['lstm_dropout'], cfg['linear_dropout'], cfg['use_kb'], cfg['use_doc']))
    # cuda_condition = torch.cuda.is_available() and with_cuda
    # self.device = torch.device("cuda:0" if cuda_condition else "cpu")

    my_model = use_cuda(GraftNet(num_kb_relation, num_entities, num_vocab,
                                 cfg))
    if cfg['load_model_file'] is not None:
        print('loading model from', cfg['load_model_file'])
        pretrained_model_states = torch.load(cfg['load_model_file'])
        if cfg['word_emb_file'] is not None:
            del pretrained_model_states['word_embedding.weight']
        if cfg['entity_emb_file'] is not None:
            del pretrained_model_states['entity_embedding.weight']
        my_model.load_state_dict(pretrained_model_states, strict=False)

    return my_model
Example No. 6
import torch
from torch.autograd import Variable
import torch.nn as nn
from util import use_cuda
import numpy as np

answer_type = 6
weight = use_cuda(
    Variable(torch.Tensor([1.26, 6.89, 23.66, 149.49, 104.40, 505.41])))


class Classifier(nn.Module):
    def __init__(self, pretrained_entity_emb_file, pretrained_entity_kge_file,
                 num_entity, word_dim, kge_dim, entity_dim, linear_dropout):
        super(Classifier, self).__init__()
        self.has_entity_kge = True
        self.num_entity = num_entity
        self.entity_dim = entity_dim
        self.word_dim = word_dim
        self.kge_dim = kge_dim

        # initialize entity embedding
        self.entity_embedding = nn.Embedding(num_embeddings=num_entity + 1,
                                             embedding_dim=word_dim,
                                             padding_idx=num_entity)
        if pretrained_entity_emb_file is not None:
            self.entity_embedding.weight = nn.Parameter(
                torch.from_numpy(
                    np.pad(np.load(pretrained_entity_emb_file),
                           ((0, 1), (0, 0)),
                           'constant')).type('torch.FloatTensor'))
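
The np.pad(..., ((0, 1), (0, 0)), 'constant') call above appends a single all-zero row to the pretrained matrix so that index num_entity can serve as the embedding's padding index. A standalone illustration of just that step:

import numpy as np

pretrained = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 entities, dim 3
padded = np.pad(pretrained, ((0, 1), (0, 0)), 'constant')
print(padded.shape)   # (5, 3): one extra row for the padding index
print(padded[-1])     # [0. 0. 0.]
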
Example No. 7
def train(n_heads=8,
          depth=4,
          seq_length=32,
          n_tokens=256,
          emb_size=128,
          n_batches=500,
          batch_size=64,
          test_every=50,
          lr=0.0001,
          warmup=100,
          seed=-1,
          data_sub=1000,
          output_path="genmodel.pt"):
    """
	Train the model and save it to output_path
	"""
    # Seed the RNG; draw a random seed if none was provided, and actually apply it
    if seed < 0:
        seed = random.randint(0, 1000000)
        print("Using seed: ", seed)
    torch.manual_seed(seed)

    # Load training data
    data_train, data_valid = get_data()
    losses = []
    # Create the model
    model = tf.GenTransformer(emb=emb_size,
                              n_heads=n_heads,
                              depth=depth,
                              seq_length=seq_length,
                              n_tokens=n_tokens)
    if util.use_cuda():
        model = model.cuda()
    # Optimizer
    opt = torch.optim.Adam(model.parameters(), lr)
    # Train over batches of random sequences
    for i in tqdm.trange(n_batches):  # tqdm is a nice progress bar (run all n_batches so i reaches n_batches - 1)
        # Warm up the learning rate by linearly increasing it to the provided lr
        if lr > 0 and i < warmup:
            warmup_lr = max((lr / warmup) * i, 1e-10)
            for group in opt.param_groups:
                group['lr'] = warmup_lr
        # Prevent gradient accumulation
        opt.zero_grad()
        # Sample batch of random subsequences
        starts = torch.randint(size=(batch_size, ),
                               low=0,
                               high=data_train.size(0) - seq_length - 1)
        seqs_source = [
            data_train[start:start + seq_length] for start in starts
        ]
        # The target is the same as the source sequence except one character ahead
        seqs_target = [
            data_train[start + 1:start + seq_length + 1] for start in starts
        ]
        source = torch.cat([s[None, :] for s in seqs_source],
                           dim=0).to(torch.long)
        target = torch.cat([s[None, :] for s in seqs_target],
                           dim=0).to(torch.long)
        # Get cuda
        if util.use_cuda():
            source, target = source.cuda(), target.cuda()
        source, target = Variable(source), Variable(target)
        # Initialize the output
        output = model(source)
        # Get the loss
        loss = F.nll_loss(output.transpose(2, 1), target, reduction='mean')
        loss.backward()
        losses.append(loss.item())
        # Clip the gradients
        nn.utils.clip_grad_norm_(model.parameters(), 1)

        # Perform optimization step
        opt.step()
        # Validate every so often, compute compression then generate
        if i != 0 and (i % test_every == 0 or i == n_batches - 1):
            # TODO sort of arbitrary, make this rigorous
            upto = data_valid.size(0) if i == n_batches - 1 else 100
            data_sub = data_valid[:upto]
            #
            with torch.no_grad():
                bits = 0.0
                # When this buffer is full we run it through the model
                batch = []
                for current in range(data_sub.size(0)):
                    fr = max(0, current - seq_length)
                    to = current + 1
                    context = data_sub[fr:to].to(torch.long)
                    # If the data doesnt fit the sequence length pad it
                    if context.size(0) < seq_length + 1:
                        pad = torch.zeros(size=(seq_length + 1 -
                                                context.size(0), ),
                                          dtype=torch.long)
                        context = torch.cat([pad, context], dim=0)
                        assert context.size(0) == seq_length + 1
                    # Get cuda
                    if util.use_cuda():
                        context = context.cuda()
                    # Fill the batch
                    batch.append(context[None, :])
                    # Check if the batch is full
                    if len(batch) == batch_size or current == data_sub.size(0) - 1:
                        # Run through model
                        b = len(batch)
                        all = torch.cat(batch, dim=0)
                        source = all[:, :-1]  # Input
                        target = all[:, -1]  # Target values
                        #
                        output = model(source)
                        # Get probabilities and convert to bits
                        lnprobs = output[torch.arange(b, device=util.device()),
                                         -1, target]
                        log2probs = lnprobs * math.log2(math.e)
                        # For logging
                        bits += log2probs.sum()
                        # Empty batch buffer
                        batch = []
                # Print validation performance
                bits_per_byte = abs(bits / data_sub.size(0))
                print(f' epoch {i}: {bits_per_byte:.4} bits per byte')
                print("Loss:", loss.item())
                # Monitor progress by generating data based on the validation data
                seedfr = random.randint(0, data_valid.size(0) - seq_length)
                input = data_valid[seedfr:seedfr + seq_length].to(torch.long)
                output_valid = gen(model, input)
                print("OUT:", output_valid[:30])
    # Save the model when we're done training it
    util.save_model(model, output_path)
    print("Finished training. Model saved to", output_path)
    return losses
Example No. 8
    def init_hidden(self, num_layer, batch_size, hidden_size):
        return (use_cuda(
            Variable(torch.zeros(num_layer, batch_size, hidden_size))),
                use_cuda(
                    Variable(torch.zeros(num_layer, batch_size, hidden_size))))
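
The tuple returned by init_hidden has the shape an nn.LSTM expects for its initial (h_0, c_0) state. A plain-PyTorch version of the same pattern, with made-up dimensions:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=300, hidden_size=100, num_layers=1, batch_first=True)
h0 = torch.zeros(1, 8, 100)   # (num_layer, batch_size, hidden_size)
c0 = torch.zeros(1, 8, 100)
x = torch.randn(8, 20, 300)   # (batch_size, seq_len, input_size)
out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)              # torch.Size([8, 20, 100])
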
Example No. 9
    def forward(self, batch):
        """
        :local_entity: global_id of each entity                     (batch_size, max_local_entity)
        :q2e_adj_mat: adjacency matrices (dense)                    (batch_size, max_local_entity, 1)
        :kb_adj_mat: adjacency matrices (sparse)                    (batch_size, max_fact, max_local_entity), (batch_size, max_local_entity, max_fact)
        :kb_fact_rel:                                               (batch_size, max_fact)
        :query_text: a list of words in the query                   (batch_size, max_query_word)
        :document_text:                                             (batch_size, max_relevant_doc, max_document_word)
        :entity_pos: sparse entity_pos_mat                          (batch_size, max_local_entity, max_relevant_doc * max_document_word) 
        :answer_dist: a distribution over local_entity              (batch_size, max_local_entity)
        """
        local_entity, q2e_adj_mat, kb_adj_mat, kb_fact_rel, query_text, document_text, entity_pos, answer_dist = batch

        batch_size, max_local_entity = local_entity.shape
        _, max_relevant_doc, max_document_word = document_text.shape
        _, max_fact = kb_fact_rel.shape

        # numpy to tensor
        local_entity = use_cuda(
            Variable(torch.from_numpy(local_entity).type('torch.LongTensor'),
                     requires_grad=False))
        local_entity_mask = use_cuda(
            (local_entity != self.num_entity).type('torch.FloatTensor'))
        if self.use_kb:
            kb_fact_rel = use_cuda(
                Variable(
                    torch.from_numpy(kb_fact_rel).type('torch.LongTensor'),
                    requires_grad=False))
        query_text = use_cuda(
            Variable(torch.from_numpy(query_text).type('torch.LongTensor'),
                     requires_grad=False))
        query_mask = use_cuda(
            (query_text != self.num_word).type('torch.FloatTensor'))
        if self.use_doc:
            document_text = use_cuda(
                Variable(
                    torch.from_numpy(document_text).type('torch.LongTensor'),
                    requires_grad=False))
            document_mask = use_cuda(
                (document_text != self.num_word).type('torch.FloatTensor'))
        answer_dist = use_cuda(
            Variable(torch.from_numpy(answer_dist).type('torch.FloatTensor'),
                     requires_grad=False))

        # normalized adj matrix
        pagerank_f = use_cuda(
            Variable(torch.from_numpy(q2e_adj_mat).type('torch.FloatTensor'),
                     requires_grad=True)).squeeze(
                         dim=2)  # batch_size, max_local_entity
        pagerank_d = use_cuda(
            Variable(torch.from_numpy(q2e_adj_mat).type('torch.FloatTensor'),
                     requires_grad=False)).squeeze(
                         dim=2)  # batch_size, max_local_entity
        q2e_adj_mat = use_cuda(
            Variable(torch.from_numpy(q2e_adj_mat).type('torch.FloatTensor'),
                     requires_grad=False))  # batch_size, max_local_entity, 1
        assert pagerank_f.requires_grad == True
        assert pagerank_d.requires_grad == False

        # encode query
        query_word_emb = self.word_embedding(
            query_text)  # batch_size, max_query_word, word_dim
        query_hidden_emb, (query_node_emb, _) = self.node_encoder(
            self.lstm_drop(query_word_emb),
            self.init_hidden(1, batch_size,
                             self.entity_dim))  # 1, batch_size, entity_dim
        query_node_emb = query_node_emb.squeeze(dim=0).unsqueeze(
            dim=1)  # batch_size, 1, entity_dim
        query_rel_emb = query_node_emb  # batch_size, 1, entity_dim

        if self.use_kb:
            # build kb_adj_matrix from sparse matrix
            (e2f_batch, e2f_f, e2f_e, e2f_val), (f2e_batch, f2e_e, f2e_f,
                                                 f2e_val) = kb_adj_mat
            entity2fact_index = torch.LongTensor([e2f_batch, e2f_f, e2f_e])
            entity2fact_val = torch.FloatTensor(e2f_val)
            entity2fact_mat = use_cuda(
                torch.sparse.FloatTensor(
                    entity2fact_index, entity2fact_val,
                    torch.Size([batch_size, max_fact, max_local_entity
                                ])))  # batch_size, max_fact, max_local_entity

            fact2entity_index = torch.LongTensor([f2e_batch, f2e_e, f2e_f])
            fact2entity_val = torch.FloatTensor(f2e_val)
            fact2entity_mat = use_cuda(
                torch.sparse.FloatTensor(
                    fact2entity_index, fact2entity_val,
                    torch.Size([batch_size, max_local_entity, max_fact])))

            # load fact embedding
            local_fact_emb = self.relation_embedding(
                kb_fact_rel)  # batch_size, max_fact, 2 * word_dim
            if self.has_relation_kge:
                local_fact_emb = torch.cat(
                    (local_fact_emb, self.relation_kge(kb_fact_rel)),
                    dim=2)  # batch_size, max_fact, 2 * word_dim + kge_dim
            local_fact_emb = self.relation_linear(
                local_fact_emb)  # batch_size, max_fact, entity_dim

            # attention fact2question
            div = float(np.sqrt(self.entity_dim))
            fact2query_sim = torch.bmm(
                query_hidden_emb, local_fact_emb.transpose(
                    1, 2)) / div  # batch_size, max_query_word, max_fact
            fact2query_sim = self.softmax_d1(
                fact2query_sim + (1 - query_mask.unsqueeze(dim=2)) *
                VERY_NEG_NUMBER)  # batch_size, max_query_word, max_fact
            fact2query_att = torch.sum(
                fact2query_sim.unsqueeze(dim=3) *
                query_hidden_emb.unsqueeze(dim=2),
                dim=1)  # batch_size, max_fact, entity_dim
            W = torch.sum(fact2query_att * local_fact_emb,
                          dim=2) / div  # batch_size, max_fact
            W_max = torch.max(W, dim=1, keepdim=True)[0]  # batch_size, 1
            W_tilde = torch.exp(W - W_max)  # batch_size, max_fact
            e2f_softmax = sparse_bmm(entity2fact_mat.transpose(1, 2),
                                     W_tilde.unsqueeze(dim=2)).squeeze(
                                         dim=2)  # batch_size, max_local_entity
            e2f_softmax = torch.clamp(e2f_softmax, min=VERY_SMALL_NUMBER)
            e2f_out_dim = use_cuda(
                Variable(torch.sum(entity2fact_mat.to_dense(), dim=1),
                         requires_grad=False))  # batch_size, max_local_entity

        # build entity_pos matrix
        if self.use_doc:
            entity_pos_dim_batch, entity_pos_dim_entity, entity_pos_dim_doc_by_word, entity_pos_value = entity_pos
            entity_pos_index = torch.LongTensor([
                entity_pos_dim_batch, entity_pos_dim_entity,
                entity_pos_dim_doc_by_word
            ])
            entity_pos_val = torch.FloatTensor(entity_pos_value)
            entity_pos_mat = use_cuda(
                torch.sparse.FloatTensor(
                    entity_pos_index, entity_pos_val,
                    torch.Size([
                        batch_size, max_local_entity,
                        max_relevant_doc * max_document_word
                    ]))
            )  # batch_size, max_local_entity, max_relevant_doc * max_document_word
            d2e_adj_mat = torch.sum(
                entity_pos_mat.to_dense().view(batch_size, max_local_entity,
                                               max_relevant_doc,
                                               max_document_word),
                dim=3)  # batch_size, max_local_entity, max_relevant_doc
            d2e_adj_mat = use_cuda(Variable(d2e_adj_mat, requires_grad=False))
            e2d_out_dim = torch.sum(
                torch.sum(entity_pos_mat.to_dense().view(
                    batch_size, max_local_entity, max_relevant_doc,
                    max_document_word),
                          dim=3),
                dim=2,
                keepdim=True)  # batch_size, max_local_entity, 1
            e2d_out_dim = use_cuda(Variable(e2d_out_dim, requires_grad=False))
            e2d_out_dim = torch.clamp(e2d_out_dim, min=VERY_SMALL_NUMBER)

            d2e_out_dim = torch.sum(
                torch.sum(entity_pos_mat.to_dense().view(
                    batch_size, max_local_entity, max_relevant_doc,
                    max_document_word),
                          dim=3),
                dim=1,
                keepdim=True)  # batch_size, 1, max_relevant_doc
            d2e_out_dim = use_cuda(Variable(d2e_out_dim, requires_grad=False))
            d2e_out_dim = torch.clamp(d2e_out_dim,
                                      min=VERY_SMALL_NUMBER).transpose(1, 2)

            # encode document
            document_textual_emb = self.word_embedding(
                document_text.view(batch_size * max_relevant_doc,
                                   max_document_word)
            )  # batch_size * max_relevant_doc, max_document_word, entity_dim
            document_textual_emb, (document_node_emb, _) = read_padded(
                self.bi_text_encoder, self.lstm_drop(document_textual_emb),
                document_mask.view(-1, max_document_word)
            )  # batch_size * max_relevant_doc, max_document_word, entity_dim
            # sum the two halves (forward/backward) of the bidirectional encoder output
            document_textual_emb = (
                document_textual_emb[:, :, 0:self.entity_dim] +
                document_textual_emb[:, :, self.entity_dim:])
            document_textual_emb = document_textual_emb.contiguous().view(
                batch_size, max_relevant_doc, max_document_word,
                self.entity_dim)
            document_node_emb = (
                document_node_emb[0, :, :] + document_node_emb[1, :, :]
            ).view(batch_size, max_relevant_doc,
                   self.entity_dim)  # batch_size, max_relevant_doc, entity_dim

        # load entity embedding
        local_entity_emb = self.entity_embedding(
            local_entity)  # batch_size, max_local_entity, word_dim
        if self.has_entity_kge:
            local_entity_emb = torch.cat(
                (local_entity_emb, self.entity_kge(local_entity)),
                dim=2)  # batch_size, max_local_entity, word_dim + kge_dim
        if self.word_dim != self.entity_dim:
            local_entity_emb = self.entity_linear(
                local_entity_emb)  # batch_size, max_local_entity, entity_dim

        # label propagation on entities
        for i in range(self.num_layer):
            # get linear transformation functions for each layer
            q2e_linear = getattr(self, 'q2e_linear' + str(i))
            d2e_linear = getattr(self, 'd2e_linear' + str(i))
            e2q_linear = getattr(self, 'e2q_linear' + str(i))
            e2d_linear = getattr(self, 'e2d_linear' + str(i))
            e2e_linear = getattr(self, 'e2e_linear' + str(i))
            if self.use_kb:
                kb_self_linear = getattr(self, 'kb_self_linear' + str(i))
                kb_head_linear = getattr(self, 'kb_head_linear' + str(i))
                kb_tail_linear = getattr(self, 'kb_tail_linear' + str(i))

            # start propagation
            # next_local_entity_emb is effectively the updated entity embedding:
            # the graph is built as training proceeds and the embedding is read off that graph,
            # combining question, KB, document, and entity information into a new entity embedding.
            next_local_entity_emb = local_entity_emb

            # STEP 1: propagate from question, documents, and facts to entities
            # (this amounts to building the knowledge graph around the entities)
            # Approach: document -> entity via Personalized PageRank; fact -> entity via DrQA
            # query_node_emb: the query encoded by the LSTM
            # question -> entity
            q2e_emb = q2e_linear(self.linear_drop(query_node_emb)).expand(
                batch_size, max_local_entity,
                self.entity_dim)  # batch_size, max_local_entity, entity_dim
            next_local_entity_emb = torch.cat(
                (next_local_entity_emb, q2e_emb),
                dim=2)  # batch_size, max_local_entity, entity_dim * 2

            # document -> entity d2e_emb
            if self.use_doc:
                pagerank_e2d = sparse_bmm(
                    entity_pos_mat.transpose(1, 2),
                    pagerank_d.unsqueeze(dim=2) / e2d_out_dim
                )  # batch_size, max_relevant_doc * max_document_word, 1
                pagerank_e2d = pagerank_e2d.view(batch_size, max_relevant_doc,
                                                 max_document_word)
                pagerank_e2d = torch.sum(pagerank_e2d,
                                         dim=2)  # batch_size, max_relevant_doc
                pagerank_e2d = pagerank_e2d / torch.clamp(
                    torch.sum(pagerank_e2d, dim=1, keepdim=True),
                    min=VERY_SMALL_NUMBER)  # batch_size, max_relevant_doc
                pagerank_e2d = pagerank_e2d.unsqueeze(dim=2).expand(
                    batch_size, max_relevant_doc, max_document_word
                )  # batch_size, max_relevant_doc, max_document_word
                pagerank_e2d = pagerank_e2d.contiguous().view(
                    batch_size, max_relevant_doc * max_document_word
                )  # batch_size, max_relevant_doc * max_document_word
                pagerank_d2e = sparse_bmm(
                    entity_pos_mat, pagerank_e2d.unsqueeze(
                        dim=2))  # batch_size, max_local_entity, 1
                pagerank_d2e = pagerank_d2e.squeeze(
                    dim=2)  # batch_size, max_local_entity
                pagerank_d2e = pagerank_d2e / torch.clamp(
                    torch.sum(pagerank_d2e, dim=1, keepdim=True),
                    min=VERY_SMALL_NUMBER)
                pagerank_d = self.pagerank_lambda * pagerank_d2e + (
                    1 - self.pagerank_lambda) * pagerank_d

                d2e_emb = sparse_bmm(
                    entity_pos_mat,
                    d2e_linear(
                        document_textual_emb.view(
                            batch_size, max_relevant_doc * max_document_word,
                            self.entity_dim)))
                d2e_emb = d2e_emb * pagerank_d.unsqueeze(
                    dim=2)  # batch_size, max_local_entity, entity_dim

            # fact -> entity f2e_emb
            if self.use_kb:
                e2f_emb = self.relu(
                    kb_self_linear(local_fact_emb) + sparse_bmm(
                        entity2fact_mat,
                        kb_head_linear(self.linear_drop(local_entity_emb)))
                )  # batch_size, max_fact, entity_dim
                e2f_softmax_normalized = W_tilde.unsqueeze(dim=2) * sparse_bmm(
                    entity2fact_mat,
                    (pagerank_f /
                     e2f_softmax).unsqueeze(dim=2))  # batch_size, max_fact, 1
                e2f_emb = e2f_emb * e2f_softmax_normalized  # batch_size, max_fact, entity_dim
                f2e_emb = self.relu(
                    kb_self_linear(local_entity_emb) +
                    sparse_bmm(fact2entity_mat,
                               kb_tail_linear(self.linear_drop(e2f_emb))))

                pagerank_f = self.pagerank_lambda * sparse_bmm(
                    fact2entity_mat, e2f_softmax_normalized).squeeze(dim=2) + (
                        1 - self.pagerank_lambda
                    ) * pagerank_f  # batch_size, max_local_entity

            # STEP 2: combine embeddings from fact and documents
            # merge the graph embeddings produced in STEP 1
            if self.use_doc and self.use_kb:
                next_local_entity_emb = torch.cat(
                    (next_local_entity_emb,
                     self.fact_scale * f2e_emb + d2e_emb),
                    dim=2)  # batch_size, max_local_entity, entity_dim * 3
            elif self.use_doc:
                next_local_entity_emb = torch.cat(
                    (next_local_entity_emb, d2e_emb),
                    dim=2)  # batch_size, max_local_entity, entity_dim * 3
            elif self.use_kb:
                next_local_entity_emb = torch.cat(
                    (next_local_entity_emb, self.fact_scale * f2e_emb),
                    dim=2)  # batch_size, max_local_entity, entity_dim * 3
            else:
                assert False, 'using neither kb nor doc ???'

            # STEP 3: propagate from entities to update question, documents, and facts
            # propagate entity information back out (repeated when there are multiple layers)
            # entity -> document
            if self.use_doc:
                e2d_emb = torch.bmm(
                    d2e_adj_mat.transpose(1, 2),
                    e2d_linear(
                        self.linear_drop(next_local_entity_emb / e2d_out_dim)
                    ))  # batch_size, max_relevant_doc, entity_dim
                e2d_emb = sparse_bmm(
                    entity_pos_mat.transpose(1, 2),
                    e2d_linear(self.linear_drop(next_local_entity_emb))
                )  # batch_size, max_relevant_doc * max_document_word, entity_dim
                e2d_emb = e2d_emb.view(
                    batch_size, max_relevant_doc, max_document_word,
                    self.entity_dim
                )  # batch_size, max_relevant_doc, max_document_word, entity_dim
                document_textual_emb = document_textual_emb + e2d_emb  # batch_size, max_relevant_doc, max_document_word, entity_dim
                document_textual_emb = document_textual_emb.view(
                    -1, max_document_word, self.entity_dim)
                document_textual_emb, _ = read_padded(
                    self.doc_info_carrier,
                    self.lstm_drop(document_textual_emb),
                    document_mask.view(-1, max_document_word)
                )  # batch_size * max_relevant_doc, max_document_word, entity_dim
                # sum the two halves (forward/backward) of the encoder output
                document_textual_emb = (
                    document_textual_emb[:, :, 0:self.entity_dim] +
                    document_textual_emb[:, :, self.entity_dim:])
                document_textual_emb = document_textual_emb.contiguous().view(
                    batch_size, max_relevant_doc, max_document_word,
                    self.entity_dim)

            # entity -> query
            query_node_emb = torch.bmm(
                pagerank_f.unsqueeze(dim=1),
                e2q_linear(self.linear_drop(next_local_entity_emb)))

            # update entity
            local_entity_emb = self.relu(
                e2e_linear(self.linear_drop(next_local_entity_emb))
            )  # batch_size, max_local_entity, entity_dim

        # calculate loss and make prediction
        score = self.score_func(self.linear_drop(local_entity_emb)).squeeze(
            dim=2)  # batch_size, max_local_entity
        loss = self.bce_loss_logits(score, answer_dist)

        score = score + (1 - local_entity_mask) * VERY_NEG_NUMBER
        pred_dist = self.sigmoid(score) * local_entity_mask
        pred = torch.max(score, dim=1)[1]

        return loss, pred, pred_dist
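
This forward pass also depends on a sparse_bmm helper (batched sparse-by-dense matrix product) that is not shown. A deliberately simple sketch with the same call signature, densifying the sparse argument for clarity (the real helper presumably avoids that for memory reasons):

import torch

def sparse_bmm(sparse_mat, dense_mat):
    # sparse_mat: torch sparse tensor of shape (batch, n, m)
    # dense_mat:  dense tensor of shape (batch, m, k)
    # Densify and fall back to a regular batched matmul; fine for small inputs.
    return torch.bmm(sparse_mat.to_dense(), dense_mat)
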