def forward(self, batch):
    entities, answer_dist = batch

    # numpy to tensor
    entities = use_cuda(
        Variable(torch.from_numpy(entities).type('torch.LongTensor'),
                 requires_grad=False))
    answer_dist = use_cuda(
        Variable(torch.from_numpy(answer_dist).type('torch.LongTensor'),
                 requires_grad=False))

    # entity embedding
    entity_emb = self.entity_embedding(entities)
    if self.has_entity_kge:
        entity_emb = torch.cat(
            (entity_emb, self.entity_kge(entities)),
            dim=2)  # batch_size, max_local_entity, word_dim + kge_dim
    if self.word_dim != self.entity_dim:
        entity_emb = self.entity_linear1(
            self.linear_drop(entity_emb))  # batch_size, max_local_entity, entity_dim
    entity_emb = self.relu(
        self.entity_linear2(self.linear_drop(entity_emb)))

    score = self.relu(self.entity_linear3(self.linear_drop(entity_emb)))
    score = score.squeeze()
    loss = self.cross_loss(score, answer_dist)

    pred_dist = self.sigmoid(score)
    pred = torch.max(pred_dist, dim=1)[1]

    return loss, pred, pred_dist
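# Hedged usage sketch: one optimization step driven by the forward() above. `model` is
# assumed to be a Classifier instance and `batch` a tuple of numpy arrays
# (entities, answer_dist), exactly as unpacked at the top of forward(); the data
# pipeline that builds such batches is not shown in this excerpt.
def train_step_sketch(model, optimizer, batch):
    optimizer.zero_grad()
    loss, pred, pred_dist = model(batch)
    loss.backward()
    optimizer.step()
    return loss.item(), pred, pred_dist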
def get_model(cfg, num_kb_relation, num_entities, num_vocab):
    word_emb_file = None if cfg['word_emb_file'] is None else cfg['data_folder'] + cfg['word_emb_file']
    entity_emb_file = None if cfg['entity_emb_file'] is None else cfg['data_folder'] + cfg['entity_emb_file']
    entity_kge_file = None if cfg['entity_kge_file'] is None else cfg['data_folder'] + cfg['entity_kge_file']
    relation_emb_file = None if cfg['relation_emb_file'] is None else cfg['data_folder'] + cfg['relation_emb_file']
    relation_kge_file = None if cfg['relation_kge_file'] is None else cfg['data_folder'] + cfg['relation_kge_file']

    my_model = use_cuda(
        GraftNet(word_emb_file, entity_emb_file, entity_kge_file,
                 relation_emb_file, relation_kge_file, cfg['num_layer'],
                 num_kb_relation, num_entities, num_vocab, cfg['entity_dim'],
                 cfg['word_dim'], cfg['kge_dim'], cfg['pagerank_lambda'],
                 cfg['fact_scale'], cfg['lstm_dropout'], cfg['linear_dropout'],
                 cfg['use_kb'], cfg['use_doc']))

    if cfg['load_model_file'] is not None:
        print('loading model from', cfg['load_model_file'])
        pretrained_model_states = torch.load(cfg['load_model_file'])
        if word_emb_file is not None:
            del pretrained_model_states['word_embedding.weight']
        if entity_emb_file is not None:
            del pretrained_model_states['entity_embedding.weight']
        my_model.load_state_dict(pretrained_model_states, strict=False)

    return my_model
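# Hedged usage sketch for get_model(): the config keys below are exactly the ones read
# in the function above; every value is illustrative only, not the project's defaults.
_example_cfg = {
    'data_folder': 'datasets/webqsp/',
    'word_emb_file': None,        # or e.g. 'word_emb.npy' relative to data_folder
    'entity_emb_file': None,
    'entity_kge_file': None,
    'relation_emb_file': None,
    'relation_kge_file': None,
    'num_layer': 3,
    'entity_dim': 100,
    'word_dim': 300,
    'kge_dim': 100,
    'pagerank_lambda': 0.8,
    'fact_scale': 3,
    'lstm_dropout': 0.3,
    'linear_dropout': 0.2,
    'use_kb': True,
    'use_doc': False,
    'load_model_file': None,      # path to a checkpoint to warm-start from
}
# my_model = get_model(_example_cfg, num_kb_relation, num_entities, num_vocab)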
def gen(model, input, size=100, temp=0.5):
    """ Generate data from the model """
    with torch.no_grad():
        data = []

        # Init input
        if util.use_cuda():
            input = input.cuda()
            model = model.cuda()
        input = Variable(input)

        # Print the input
        print('IN: [', end='', flush=True)
        for i in input:
            print(i.item(), end=', ', flush=True)
        print(']', end='', flush=True)
        print()

        # Get generated data
        for i in range(size):
            output = model(input[None, :])
            c = sample(output[0, -1, :], temp)
            # print(" SAMPLE: ", c)
            # print(str(chr(max(32, c))), end='', flush=True)
            data.append(c.item())
            # Make the next prediction informed by this prediction
            input = torch.cat([input[1:], c[None]], dim=0)

        # print(data[:30])
        # print()
        # print(pitches)
        return data
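# Hedged usage sketch for gen(): seed it with a 1-D LongTensor of token ids (here a
# random context) and sample `size` new tokens. The seq_length and n_tokens values are
# illustrative and should match whatever the model was trained with.
def gen_demo(model, seq_length=32, n_tokens=256):
    seed = torch.randint(low=0, high=n_tokens, size=(seq_length,))
    return gen(model, seed, size=200, temp=0.8)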
def get_model(entity2id):
    my_model = use_cuda(
        Classifier(entity_emb_file, entity_kge_file, len(entity2id), word_dim,
                   kge_dim, entity_dim, linear_dropout))

    if load_model is not None:
        print("loading model from", load_model)
        pretrained_model_states = torch.load(load_model)
        if entity_emb_file is not None:
            del pretrained_model_states['entity_embedding.weight']
            del pretrained_model_states['cross_loss.weight']
        my_model.load_state_dict(pretrained_model_states, strict=False)

    return my_model
def get_model(cfg, num_kb_relation, num_entities, num_vocab):
    # word_emb_file = None if cfg['word_emb_file'] is None else cfg['data_folder'] + cfg['word_emb_file']
    # entity_emb_file = None if cfg['entity_emb_file'] is None else cfg['data_folder'] + cfg['entity_emb_file']
    # entity_kge_file = None if cfg['entity_kge_file'] is None else cfg['data_folder'] + cfg['entity_kge_file']
    # relation_emb_file = None if cfg['relation_emb_file'] is None else cfg['data_folder'] + cfg['relation_emb_file']
    # relation_kge_file = None if cfg['relation_kge_file'] is None else cfg['data_folder'] + cfg['relation_kge_file']
    # my_model = cuda(GraftNet(word_emb_file, entity_emb_file, entity_kge_file, relation_emb_file, relation_kge_file,
    #                          cfg['num_layer'], num_kb_relation, num_entities, num_vocab, cfg['entity_dim'],
    #                          cfg['word_dim'], cfg['kge_dim'], cfg['pagerank_lambda'], cfg['fact_scale'],
    #                          cfg['lstm_dropout'], cfg['linear_dropout'], cfg['use_kb'], cfg['use_doc']))
    # cuda_condition = torch.cuda.is_available() and with_cuda
    # self.device = torch.device("cuda:0" if cuda_condition else "cpu")

    my_model = use_cuda(GraftNet(num_kb_relation, num_entities, num_vocab, cfg))

    if cfg['load_model_file'] is not None:
        print('loading model from', cfg['load_model_file'])
        pretrained_model_states = torch.load(cfg['load_model_file'])
        if cfg['word_emb_file'] is not None:
            del pretrained_model_states['word_embedding.weight']
        if cfg['entity_emb_file'] is not None:
            del pretrained_model_states['entity_embedding.weight']
        my_model.load_state_dict(pretrained_model_states, strict=False)

    return my_model
import torch
from torch.autograd import Variable
import torch.nn as nn
from util import use_cuda
import numpy as np

answer_type = 6
weight = use_cuda(
    Variable(torch.Tensor([1.26, 6.89, 23.66, 149.49, 104.40, 505.41])))


class Classifier(nn.Module):
    def __init__(self, pretrained_entity_emb_file, pretrained_entity_kge_file,
                 num_entity, word_dim, kge_dim, entity_dim, linear_dropout):
        super(Classifier, self).__init__()
        self.has_entity_kge = True
        self.num_entity = num_entity
        self.entity_dim = entity_dim
        self.word_dim = word_dim
        self.kge_dim = kge_dim

        # initialize entity embedding
        self.entity_embedding = nn.Embedding(num_embeddings=num_entity + 1,
                                             embedding_dim=word_dim,
                                             padding_idx=num_entity)
        if pretrained_entity_emb_file is not None:
            self.entity_embedding.weight = nn.Parameter(
                torch.from_numpy(
                    np.pad(np.load(pretrained_entity_emb_file),
                           ((0, 1), (0, 0)),
                           'constant')).type('torch.FloatTensor'))
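# Illustrative check, with toy data, of the padding trick used above: np.pad with
# ((0, 1), (0, 0)) appends one all-zero row to the pretrained matrix so that index
# `num_entity` can serve as the padding_idx of nn.Embedding.
_toy_emb = np.ones((3, 4), dtype=np.float32)                  # 3 entities, dim 4
_toy_padded = np.pad(_toy_emb, ((0, 1), (0, 0)), 'constant')  # shape (4, 4)
assert _toy_padded.shape == (4, 4) and not _toy_padded[-1].any()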
def train(n_heads=8, depth=4, seq_length=32, n_tokens=256, emb_size=128,
          n_batches=500, batch_size=64, test_every=50, lr=0.0001, warmup=100,
          seed=-1, data_sub=1000, output_path="genmodel.pt"):
    """ Train the model and save it to output_path """
    # Seed the network
    if seed < 0:
        seed = random.randint(0, 1000000)
        print("Using seed: ", seed)
    else:
        torch.manual_seed(seed)

    # Load training data
    data_train, data_valid = get_data()
    losses = []

    # Create the model
    model = tf.GenTransformer(emb=emb_size, n_heads=n_heads, depth=depth,
                              seq_length=seq_length, n_tokens=n_tokens)
    if util.use_cuda():
        model = model.cuda()

    # Optimizer
    opt = torch.optim.Adam(model.parameters(), lr)

    # Train over batches of random sequences
    for i in tqdm.trange(n_batches - 1):  # tqdm is a nice progress bar
        # Warm up by linearly increasing the learning rate to the provided value
        # (Adam's learning rate has to be updated through its param_groups)
        if lr > 0 and i < warmup:
            warmup_lr = max((lr / warmup) * i, 1e-10)
            for g in opt.param_groups:
                g['lr'] = warmup_lr

        # Prevent gradient accumulation
        opt.zero_grad()

        # Sample batch of random subsequences
        starts = torch.randint(size=(batch_size, ), low=0,
                               high=data_train.size(0) - seq_length - 1)
        seqs_source = [data_train[start:start + seq_length] for start in starts]
        # The target is the same as the source sequence, shifted one character ahead
        seqs_target = [data_train[start + 1:start + seq_length + 1] for start in starts]
        source = torch.cat([s[None, :] for s in seqs_source], dim=0).to(torch.long)
        target = torch.cat([s[None, :] for s in seqs_target], dim=0).to(torch.long)

        # Get cuda
        if util.use_cuda():
            source, target = source.cuda(), target.cuda()
        source, target = Variable(source), Variable(target)

        # Forward pass
        output = model(source)

        # Get the loss
        loss = F.nll_loss(output.transpose(2, 1), target, reduction='mean')
        loss.backward()
        losses.append(loss.item())

        # Clip the gradients
        nn.utils.clip_grad_norm_(model.parameters(), 1)

        # Perform optimization step
        opt.step()

        # Validate every so often: compute compression, then generate
        if i != 0 and (i % test_every == 0 or i == n_batches - 1):
            # TODO sort of arbitrary, make this rigorous
            upto = data_valid.size(0) if i == n_batches - 1 else 100
            data_sub = data_valid[:upto]

            with torch.no_grad():
                bits = 0.0
                # When this buffer is full we run it through the model
                batch = []
                for current in range(data_sub.size(0)):
                    fr = max(0, current - seq_length)
                    to = current + 1
                    context = data_sub[fr:to].to(torch.long)

                    # If the data doesn't fit the sequence length, pad it
                    if context.size(0) < seq_length + 1:
                        pad = torch.zeros(size=(seq_length + 1 - context.size(0), ),
                                          dtype=torch.long)
                        context = torch.cat([pad, context], dim=0)
                        assert context.size(0) == seq_length + 1

                    # Get cuda
                    if util.use_cuda():
                        context = context.cuda()

                    # Fill the batch
                    batch.append(context[None, :])

                    # Check if the batch is full
                    if len(batch) == batch_size or current == data_sub.size(0) - 1:
                        # Run through the model
                        b = len(batch)
                        all = torch.cat(batch, dim=0)
                        source = all[:, :-1]  # Input
                        target = all[:, -1]   # Target values

                        output = model(source)

                        # Get probabilities and convert to bits
                        lnprobs = output[torch.arange(b, device=util.device()), -1, target]
                        log2probs = lnprobs * math.log2(math.e)

                        # For logging
                        bits += log2probs.sum()

                        # Empty batch buffer
                        batch = []

                # Print validation performance
                bits_per_byte = abs(bits / data_sub.size(0))
                print(f' epoch {i}: {bits_per_byte:.4} bits per byte')
                print("Loss:", loss.item())

                # Monitor progress by generating data based on the validation data
                seedfr = random.randint(0, data_valid.size(0) - seq_length)
                input = data_valid[seedfr:seedfr + seq_length].to(torch.long)
                output_valid = gen(model, input)
                print("OUT:", output_valid[:30])

    # Save the model when we're done training it
    util.save_model(model, output_path)
    # print("Finished training. Model saved to", output_path)

    return losses
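# Hedged usage sketch: train a small model and inspect the loss curve. The hyperparameter
# values here are illustrative, not tuned.
if __name__ == '__main__':
    _losses = train(n_batches=200, batch_size=32, test_every=50,
                    output_path="genmodel.pt")
    print("final training loss:", _losses[-1])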
def init_hidden(self, num_layer, batch_size, hidden_size):
    return (use_cuda(Variable(torch.zeros(num_layer, batch_size, hidden_size))),
            use_cuda(Variable(torch.zeros(num_layer, batch_size, hidden_size))))
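# `use_cuda` is imported from util throughout these snippets; its implementation is not
# shown here. A minimal sketch of what such a helper typically does (an assumption, not
# the project's actual code):
def _use_cuda_sketch(x):
    import torch
    # Move the tensor/module to the GPU when one is available, otherwise leave it on CPU.
    return x.cuda() if torch.cuda.is_available() else x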
def forward(self, batch): """ :local_entity: global_id of each entity (batch_size, max_local_entity) :q2e_adj_mat: adjacency matrices (dense) (batch_size, max_local_entity, 1) :kb_adj_mat: adjacency matrices (sparse) (batch_size, max_fact, max_local_entity), (batch_size, max_local_entity, max_fact) :kb_fact_rel: (batch_size, max_fact) :query_text: a list of words in the query (batch_size, max_query_word) :document_text: (batch_size, max_relevant_doc, max_document_word) :entity_pos: sparse entity_pos_mat (batch_size, max_local_entity, max_relevant_doc * max_document_word) :answer_dist: an distribution over local_entity (batch_size, max_local_entity) """ local_entity, q2e_adj_mat, kb_adj_mat, kb_fact_rel, query_text, document_text, entity_pos, answer_dist = batch batch_size, max_local_entity = local_entity.shape _, max_relevant_doc, max_document_word = document_text.shape _, max_fact = kb_fact_rel.shape # numpy to tensor local_entity = use_cuda( Variable(torch.from_numpy(local_entity).type('torch.LongTensor'), requires_grad=False)) local_entity_mask = use_cuda( (local_entity != self.num_entity).type('torch.FloatTensor')) if self.use_kb: kb_fact_rel = use_cuda( Variable( torch.from_numpy(kb_fact_rel).type('torch.LongTensor'), requires_grad=False)) query_text = use_cuda( Variable(torch.from_numpy(query_text).type('torch.LongTensor'), requires_grad=False)) query_mask = use_cuda( (query_text != self.num_word).type('torch.FloatTensor')) if self.use_doc: document_text = use_cuda( Variable( torch.from_numpy(document_text).type('torch.LongTensor'), requires_grad=False)) document_mask = use_cuda( (document_text != self.num_word).type('torch.FloatTensor')) answer_dist = use_cuda( Variable(torch.from_numpy(answer_dist).type('torch.FloatTensor'), requires_grad=False)) # normalized adj matrix pagerank_f = use_cuda( Variable(torch.from_numpy(q2e_adj_mat).type('torch.FloatTensor'), requires_grad=True)).squeeze( dim=2) # batch_size, max_local_entity pagerank_d = use_cuda( Variable(torch.from_numpy(q2e_adj_mat).type('torch.FloatTensor'), requires_grad=False)).squeeze( dim=2) # batch_size, max_local_entity q2e_adj_mat = use_cuda( Variable(torch.from_numpy(q2e_adj_mat).type('torch.FloatTensor'), requires_grad=False)) # batch_size, max_local_entity, 1 assert pagerank_f.requires_grad == True assert pagerank_d.requires_grad == False # encode query query_word_emb = self.word_embedding( query_text) # batch_size, max_query_word, word_dim query_hidden_emb, (query_node_emb, _) = self.node_encoder( self.lstm_drop(query_word_emb), self.init_hidden(1, batch_size, self.entity_dim)) # 1, batch_size, entity_dim query_node_emb = query_node_emb.squeeze(dim=0).unsqueeze( dim=1) # batch_size, 1, entity_dim query_rel_emb = query_node_emb # batch_size, 1, entity_dim if self.use_kb: # build kb_adj_matrix from sparse matrix (e2f_batch, e2f_f, e2f_e, e2f_val), (f2e_batch, f2e_e, f2e_f, f2e_val) = kb_adj_mat entity2fact_index = torch.LongTensor([e2f_batch, e2f_f, e2f_e]) entity2fact_val = torch.FloatTensor(e2f_val) entity2fact_mat = use_cuda( torch.sparse.FloatTensor( entity2fact_index, entity2fact_val, torch.Size([batch_size, max_fact, max_local_entity ]))) # batch_size, max_fact, max_local_entity fact2entity_index = torch.LongTensor([f2e_batch, f2e_e, f2e_f]) fact2entity_val = torch.FloatTensor(f2e_val) fact2entity_mat = use_cuda( torch.sparse.FloatTensor( fact2entity_index, fact2entity_val, torch.Size([batch_size, max_local_entity, max_fact]))) # load fact embedding local_fact_emb = self.relation_embedding( kb_fact_rel) # 
batch_size, max_fact, 2 * word_dim if self.has_relation_kge: local_fact_emb = torch.cat( (local_fact_emb, self.relation_kge(kb_fact_rel)), dim=2) # batch_size, max_fact, 2 * word_dim + kge_dim local_fact_emb = self.relation_linear( local_fact_emb) # batch_size, max_fact, entity_dim # attention fact2question div = float(np.sqrt(self.entity_dim)) fact2query_sim = torch.bmm( query_hidden_emb, local_fact_emb.transpose( 1, 2)) / div # batch_size, max_query_word, max_fact fact2query_sim = self.softmax_d1( fact2query_sim + (1 - query_mask.unsqueeze(dim=2)) * VERY_NEG_NUMBER) # batch_size, max_query_word, max_fact fact2query_att = torch.sum( fact2query_sim.unsqueeze(dim=3) * query_hidden_emb.unsqueeze(dim=2), dim=1) # batch_size, max_fact, entity_dim W = torch.sum(fact2query_att * local_fact_emb, dim=2) / div # batch_size, max_fact W_max = torch.max(W, dim=1, keepdim=True)[0] # batch_size, 1 W_tilde = torch.exp(W - W_max) # batch_size, max_fact e2f_softmax = sparse_bmm(entity2fact_mat.transpose(1, 2), W_tilde.unsqueeze(dim=2)).squeeze( dim=2) # batch_size, max_local_entity e2f_softmax = torch.clamp(e2f_softmax, min=VERY_SMALL_NUMBER) e2f_out_dim = use_cuda( Variable(torch.sum(entity2fact_mat.to_dense(), dim=1), requires_grad=False)) # batch_size, max_local_entity # build entity_pos matrix if self.use_doc: entity_pos_dim_batch, entity_pos_dim_entity, entity_pos_dim_doc_by_word, entity_pos_value = entity_pos entity_pos_index = torch.LongTensor([ entity_pos_dim_batch, entity_pos_dim_entity, entity_pos_dim_doc_by_word ]) entity_pos_val = torch.FloatTensor(entity_pos_value) entity_pos_mat = use_cuda( torch.sparse.FloatTensor( entity_pos_index, entity_pos_val, torch.Size([ batch_size, max_local_entity, max_relevant_doc * max_document_word ])) ) # batch_size, max_local_entity, max_relevant_doc * max_document_word d2e_adj_mat = torch.sum( entity_pos_mat.to_dense().view(batch_size, max_local_entity, max_relevant_doc, max_document_word), dim=3) # batch_size, max_local_entity, max_relevant_doc d2e_adj_mat = use_cuda(Variable(d2e_adj_mat, requires_grad=False)) e2d_out_dim = torch.sum( torch.sum(entity_pos_mat.to_dense().view( batch_size, max_local_entity, max_relevant_doc, max_document_word), dim=3), dim=2, keepdim=True) # batch_size, max_local_entity, 1 e2d_out_dim = use_cuda(Variable(e2d_out_dim, requires_grad=False)) e2d_out_dim = torch.clamp(e2d_out_dim, min=VERY_SMALL_NUMBER) d2e_out_dim = torch.sum( torch.sum(entity_pos_mat.to_dense().view( batch_size, max_local_entity, max_relevant_doc, max_document_word), dim=3), dim=1, keepdim=True) # batch_size, 1, max_relevant_doc d2e_out_dim = use_cuda(Variable(d2e_out_dim, requires_grad=False)) d2e_out_dim = torch.clamp(d2e_out_dim, min=VERY_SMALL_NUMBER).transpose(1, 2) # encode document document_textual_emb = self.word_embedding( document_text.view(batch_size * max_relevant_doc, max_document_word) ) # batch_size * max_relevant_doc, max_document_word, entity_dim document_textual_emb, (document_node_emb, _) = read_padded( self.bi_text_encoder, self.lstm_drop(document_textual_emb), document_mask.view(-1, max_document_word) ) # batch_size * max_relevant_doc, max_document_word, entity_dim document_textual_emb = document_textual_emb[:, :, 0:self. entity_dim] + document_textual_emb[:, :, self . 
entity_dim:] document_textual_emb = document_textual_emb.contiguous().view( batch_size, max_relevant_doc, max_document_word, self.entity_dim) document_node_emb = ( document_node_emb[0, :, :] + document_node_emb[1, :, :] ).view(batch_size, max_relevant_doc, self.entity_dim) # batch_size, max_relevant_doc, entity_dim # load entity embedding local_entity_emb = self.entity_embedding( local_entity) # batch_size, max_local_entity, word_dim if self.has_entity_kge: local_entity_emb = torch.cat( (local_entity_emb, self.entity_kge(local_entity)), dim=2) # batch_size, max_local_entity, word_dim + kge_dim if self.word_dim != self.entity_dim: local_entity_emb = self.entity_linear( local_entity_emb) # batch_size, max_local_entity, entity_dim # label propagation on entities for i in range(self.num_layer): # get linear transformation functions for each layer q2e_linear = getattr(self, 'q2e_linear' + str(i)) d2e_linear = getattr(self, 'd2e_linear' + str(i)) e2q_linear = getattr(self, 'e2q_linear' + str(i)) e2d_linear = getattr(self, 'e2d_linear' + str(i)) e2e_linear = getattr(self, 'e2e_linear' + str(i)) if self.use_kb: kb_self_linear = getattr(self, 'kb_self_linear' + str(i)) kb_head_linear = getattr(self, 'kb_head_linear' + str(i)) kb_tail_linear = getattr(self, 'kb_tail_linear' + str(i)) # start propagation # next_local_entity_emb实际就是更新的entity embedding # 边训练,边生成图边根据图得到embedding。 # 实际上就是把question, kb, document,entity四者结合成为新的entity embedding next_local_entity_emb = local_entity_emb # STEP 1: propagate from question, documents, and facts to entities 相当于建立entity的知识图 # 方法:document->entity: e Personalized PageRank. fact -> entity :DrQA # query_node_emb:是query通过lstm # question -> entity q2e_emb = q2e_linear(self.linear_drop(query_node_emb)).expand( batch_size, max_local_entity, self.entity_dim) # batch_size, max_local_entity, entity_dim next_local_entity_emb = torch.cat( (next_local_entity_emb, q2e_emb), dim=2) # batch_size, max_local_entity, entity_dim * 2 # document -> entity d2e_emb if self.use_doc: pagerank_e2d = sparse_bmm( entity_pos_mat.transpose(1, 2), pagerank_d.unsqueeze(dim=2) / e2d_out_dim ) # batch_size, max_relevant_doc * max_document_word, 1 pagerank_e2d = pagerank_e2d.view(batch_size, max_relevant_doc, max_document_word) pagerank_e2d = torch.sum(pagerank_e2d, dim=2) # batch_size, max_relevant_doc pagerank_e2d = pagerank_e2d / torch.clamp( torch.sum(pagerank_e2d, dim=1, keepdim=True), min=VERY_SMALL_NUMBER) # batch_size, max_relevant_doc pagerank_e2d = pagerank_e2d.unsqueeze(dim=2).expand( batch_size, max_relevant_doc, max_document_word ) # batch_size, max_relevant_doc, max_document_word pagerank_e2d = pagerank_e2d.contiguous().view( batch_size, max_relevant_doc * max_document_word ) # batch_size, max_relevant_doc * max_document_word pagerank_d2e = sparse_bmm( entity_pos_mat, pagerank_e2d.unsqueeze( dim=2)) # batch_size, max_local_entity, 1 pagerank_d2e = pagerank_d2e.squeeze( dim=2) # batch_size, max_local_entity pagerank_d2e = pagerank_d2e / torch.clamp( torch.sum(pagerank_d2e, dim=1, keepdim=True), min=VERY_SMALL_NUMBER) pagerank_d = self.pagerank_lambda * pagerank_d2e + ( 1 - self.pagerank_lambda) * pagerank_d d2e_emb = sparse_bmm( entity_pos_mat, d2e_linear( document_textual_emb.view( batch_size, max_relevant_doc * max_document_word, self.entity_dim))) d2e_emb = d2e_emb * pagerank_d.unsqueeze( dim=2) # batch_size, max_local_entity, entity_dim # fact -> entity f2e_emb if self.use_kb: e2f_emb = self.relu( kb_self_linear(local_fact_emb) + sparse_bmm( entity2fact_mat, 
kb_head_linear(self.linear_drop(local_entity_emb))) ) # batch_size, max_fact, entity_dim e2f_softmax_normalized = W_tilde.unsqueeze(dim=2) * sparse_bmm( entity2fact_mat, (pagerank_f / e2f_softmax).unsqueeze(dim=2)) # batch_size, max_fact, 1 e2f_emb = e2f_emb * e2f_softmax_normalized # batch_size, max_fact, entity_dim f2e_emb = self.relu( kb_self_linear(local_entity_emb) + sparse_bmm(fact2entity_mat, kb_tail_linear(self.linear_drop(e2f_emb)))) pagerank_f = self.pagerank_lambda * sparse_bmm( fact2entity_mat, e2f_softmax_normalized).squeeze(dim=2) + ( 1 - self.pagerank_lambda ) * pagerank_f # batch_size, max_local_entity # STEP 2: combine embeddings from fact and documents # 将第一步形成的知识图embedding合并 if self.use_doc and self.use_kb: next_local_entity_emb = torch.cat( (next_local_entity_emb, self.fact_scale * f2e_emb + d2e_emb), dim=2) # batch_size, max_local_entity, entity_dim * 3 elif self.use_doc: next_local_entity_emb = torch.cat( (next_local_entity_emb, d2e_emb), dim=2) # batch_size, max_local_entity, entity_dim * 3 elif self.use_kb: next_local_entity_emb = torch.cat( (next_local_entity_emb, self.fact_scale * f2e_emb), dim=2) # batch_size, max_local_entity, entity_dim * 3 else: assert False, 'using neither kb nor doc ???' # STEP 3: propagate from entities to update question, documents, and facts # entity信息传播。多层的话 # entity -> document if self.use_doc: e2d_emb = torch.bmm( d2e_adj_mat.transpose(1, 2), e2d_linear( self.linear_drop(next_local_entity_emb / e2d_out_dim) )) # batch_size, max_relevant_doc, entity_dim e2d_emb = sparse_bmm( entity_pos_mat.transpose(1, 2), e2d_linear(self.linear_drop(next_local_entity_emb)) ) # batch_size, max_relevant_doc * max_document_word, entity_dim e2d_emb = e2d_emb.view( batch_size, max_relevant_doc, max_document_word, self.entity_dim ) # batch_size, max_relevant_doc, max_document_word, entity_dim document_textual_emb = document_textual_emb + e2d_emb # batch_size, max_relevant_doc, max_document_word, entity_dim document_textual_emb = document_textual_emb.view( -1, max_document_word, self.entity_dim) document_textual_emb, _ = read_padded( self.doc_info_carrier, self.lstm_drop(document_textual_emb), document_mask.view(-1, max_document_word) ) # batch_size * max_relevant_doc, max_document_word, entity_dim document_textual_emb = document_textual_emb[:, :, 0:self. entity_dim] + document_textual_emb[:, :, self . entity_dim:] document_textual_emb = document_textual_emb.contiguous().view( batch_size, max_relevant_doc, max_document_word, self.entity_dim) # entity -> query query_node_emb = torch.bmm( pagerank_f.unsqueeze(dim=1), e2q_linear(self.linear_drop(next_local_entity_emb))) # update entity local_entity_emb = self.relu( e2e_linear(self.linear_drop(next_local_entity_emb)) ) # batch_size, max_local_entity, entity_dim # calculate loss and make prediction score = self.score_func(self.linear_drop(local_entity_emb)).squeeze( dim=2) # batch_size, max_local_entity loss = self.bce_loss_logits(score, answer_dist) score = score + (1 - local_entity_mask) * VERY_NEG_NUMBER pred_dist = self.sigmoid(score) * local_entity_mask pred = torch.max(score, dim=1)[1] return loss, pred, pred_dist
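# Hedged sketch of the sparse propagation pattern used in forward() above: build a
# batched sparse adjacency matrix from COO indices/values (as done for entity2fact_mat)
# and multiply it with a dense feature tensor. Recent PyTorch accepts a sparse COO first
# operand in torch.bmm; the project's sparse_bmm helper is assumed to do the equivalent
# (possibly via a custom autograd function). All sizes below are toy values.
def _sparse_bmm_demo():
    import torch
    batch_size, max_fact, max_local_entity, entity_dim = 2, 4, 3, 8
    # three non-zero entries, given as (batch, fact, entity) coordinates
    index = torch.tensor([[0, 0, 1],    # batch indices
                          [0, 2, 3],    # fact indices
                          [0, 1, 2]])   # entity indices
    values = torch.ones(3)
    adj = torch.sparse_coo_tensor(
        index, values, (batch_size, max_fact, max_local_entity))
    feats = torch.randn(batch_size, max_local_entity, entity_dim)
    return torch.bmm(adj, feats)        # (batch_size, max_fact, entity_dim)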