def forward(self, x, adj, length=None):
    batch_size, node_num, feature_dim = x.shape
    h = to_gpu(Variable(torch.from_numpy(x), requires_grad=False)).float()
    length_mask = None
    class_mask = None
    if length is not None:
        lengths_var = to_gpu(
            Variable(torch.from_numpy(length), requires_grad=False)).long()
        # batch_size * node_num
        length_mask = sequence_mask(lengths_var, node_num)
        class_mask = length_mask.unsqueeze(2).expand(
            batch_size, node_num, 2).float()
    # adj: batch * node_num * node_num
    adj = to_gpu(Variable(torch.from_numpy(adj), requires_grad=False)).float()
    h = self.gc_layer(h, adj, mask=length_mask)
    # h: batch * node_num * hidden
    if self.class_ln:
        h = self.ln_inp(h)
    h = F.dropout(h, self.drop_out_rate, training=self.training)
    output = self.classifer(h)
    # batch_size * node_num * self._num_class
    output = masked_softmax(output, mask=class_mask)
    return output
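
# The forward pass above relies on `sequence_mask` and `masked_softmax`
# helpers defined elsewhere in the codebase. A minimal sketch of plausible
# implementations is given below; these are illustrative assumptions, not
# the repository's actual definitions.
def sequence_mask(lengths, max_len):
    # lengths: batch_size LongTensor Variable; returns a batch_size * max_len
    # mask with 1 at valid positions and 0 at padding positions.
    batch_size = lengths.size(0)
    positions = torch.arange(0, max_len).long().unsqueeze(0).expand(
        batch_size, max_len)
    positions = to_gpu(Variable(positions, requires_grad=False))
    return (positions < lengths.unsqueeze(1)).float()

def masked_softmax(logits, mask=None):
    # Softmax over the last dimension that ignores masked (0) positions by
    # zeroing their exponentiated scores and renormalizing.
    exps = torch.exp(logits - logits.max(dim=-1, keepdim=True)[0])
    if mask is not None:
        exps = exps * mask
    return exps / (exps.sum(dim=-1, keepdim=True) + 1e-13)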
def getNeighborMask(self, num_mentions, dim):
    # num_mentions: batch * cand_num
    batch_size, cand_num = num_mentions.shape
    margin_col = to_gpu(torch.zeros(1, cand_num))
    right_mask = to_gpu(torch.from_numpy(num_mentions)).float()
    # shift down by one row so each row sees its left (previous) neighbor
    left_mask = torch.cat([margin_col, right_mask[:-1, :]], dim=0)
    # (batch * cand_num) * dim
    right_mask_expand = right_mask.view(-1).unsqueeze(1).expand(
        batch_size * cand_num, dim)
    left_mask_expand = left_mask.view(-1).unsqueeze(1).expand(
        batch_size * cand_num, dim)
    return left_mask_expand, right_mask_expand
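
# Shape check for getNeighborMask (illustrative only; `model` stands for any
# instance exposing the method). With 2 rows, 3 candidates per mention, and
# dim 4, both returned masks are (2 * 3) x 4, and the left mask is the right
# mask shifted down by one row with a zero row prepended:
#   num_mentions = np.ones((2, 3), dtype=np.float32)
#   left, right = model.getNeighborMask(num_mentions, 4)
#   assert left.size() == (6, 4) and right.size() == (6, 4)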
def getCandidateEmbedding(self, candidates, candidates_sense=None):
    candidates = to_gpu(
        Variable(torch.from_numpy(candidates),
                 volatile=not self.training)).long()
    cand_entity_emb = self.run_embed(candidates, 1)
    cand_sense_emb = None
    cand_mu_emb = None
    if candidates_sense is not None and self._has_sense:
        candidates_sense = to_gpu(
            Variable(torch.from_numpy(candidates_sense),
                     volatile=not self.training)).long()
        cand_sense_emb = self.run_embed(candidates_sense, 2)
        cand_mu_emb = self.run_embed(candidates_sense, 3)
    return cand_entity_emb, cand_sense_emb, cand_mu_emb
def getCandidateSimilarity(self, embeddings, candidate_embeddings,
                           default_sims=None):
    cand_entity_emb, cand_sense_emb, cand_mu_emb = candidate_embeddings
    entity_emb, sense_emb, mu_emb = embeddings
    if default_sims is None:
        batch_size, _ = entity_emb.size()
        default_sims = to_gpu(
            Variable(torch.FloatTensor([DEFAULT_SIM] *
                                       batch_size).unsqueeze(1),
                     requires_grad=False))
    cand_entity_emb_expand = cand_entity_emb.unsqueeze(1)
    sim1 = torch.bmm(cand_entity_emb_expand,
                     entity_emb.unsqueeze(2)).squeeze(2)
    # fall back to the default similarity tensor when sense/mu embeddings
    # are absent, so downstream torch.cat over features still works
    sim2 = default_sims
    sim3 = default_sims
    if cand_sense_emb is not None and cand_mu_emb is not None:
        cand_sense_emb_expand = cand_sense_emb.unsqueeze(1)
        cand_mu_emb_expand = cand_mu_emb.unsqueeze(1)
        sim2 = torch.bmm(cand_sense_emb_expand,
                         sense_emb.unsqueeze(2)).squeeze(2)
        sim3 = torch.bmm(cand_mu_emb_expand,
                         mu_emb.unsqueeze(2)).squeeze(2)
    return sim1, sim2, sim3
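
# Note on the bmm pattern used above: for x and y of size B x D,
#   torch.bmm(x.unsqueeze(1), y.unsqueeze(2)).squeeze(2)
# multiplies a B x 1 x D batch with a B x D x 1 batch, yielding a B x 1
# tensor of row-wise dot products <x_i, y_i>. An equivalent formulation is
#   (x * y).sum(dim=1, keepdim=True)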
def getNeighEmb(self, mstr_emb, cand_num, neighbor_window, left_mask,
                right_mask):
    margin_col = to_gpu(
        Variable(torch.zeros(cand_num, self._dim), requires_grad=False))
    # left_neighs: (batch_size*cand_num) * window * dim
    tmp_left_neigh_list = [
        self.leftMvNeigh(mstr_emb, cand_num, margin_col, left_mask)
    ]
    for i in range(neighbor_window - 1):
        tmp_left_neigh_list.append(
            self.leftMvNeigh(tmp_left_neigh_list[i], cand_num, margin_col,
                             left_mask))
    left_neighs = torch.cat(
        [neigh.unsqueeze(1) for neigh in tmp_left_neigh_list], dim=1)

    tmp_right_neigh_list = [
        self.rightMvNeigh(mstr_emb, cand_num, margin_col, right_mask)
    ]
    for i in range(neighbor_window - 1):
        tmp_right_neigh_list.append(
            self.rightMvNeigh(tmp_right_neigh_list[i], cand_num, margin_col,
                              right_mask))
    right_neighs = torch.cat(
        [neigh.unsqueeze(1) for neigh in tmp_right_neigh_list], dim=1)

    # neigh_emb: (batch_size*cand_num) * 2window * dim
    neigh_emb = torch.cat((left_neighs, right_neighs), dim=1)
    # average over the window: (batch_size*cand_num) * dim
    neigh_emb = torch.mean(neigh_emb, dim=1)
    return neigh_emb
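
# getNeighEmb assumes leftMvNeigh / rightMvNeigh helpers that shift the
# (batch_size*cand_num) x dim matrix of mention-string embeddings by one
# mention (cand_num rows) and zero out rows with no neighbor on that side.
# A plausible sketch under those assumptions (not the repo's actual code):
def leftMvNeigh(self, emb, cand_num, margin_col, left_mask):
    # pad cand_num zero rows on top, drop the last cand_num rows, then mask
    shifted = torch.cat([margin_col, emb[:-cand_num, :]], dim=0)
    return shifted * left_mask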
def getNeighCandidates(self, emb, window, num_mentions):
    batch_size, cand_num = num_mentions.shape
    _, dim = emb.size()
    left_mask, right_mask = self.getNeighborMask(num_mentions, dim)
    margin_col = to_gpu(torch.zeros(cand_num, dim))

    left_list = []
    # (batch * cand) * dim
    left_list.append(self.leftNeighbor(emb, cand_num, margin_col, left_mask))
    for i in range(window - 1):
        left_list.append(
            self.leftNeighbor(left_list[i], cand_num, margin_col, left_mask))
    for i in range(window):
        left_list[i] = self.getExpandNeighCandidates(left_list[i],
                                                     batch_size, cand_num,
                                                     dim)
    # (batch * cand) * (window*cand) * dim
    left_cands = torch.cat(left_list, dim=1)

    right_list = []
    right_list.append(
        self.rightNeighbor(emb, cand_num, margin_col, right_mask))
    for i in range(window - 1):
        right_list.append(
            self.rightNeighbor(right_list[i], cand_num, margin_col,
                               right_mask))
    for i in range(window):
        right_list[i] = self.getExpandNeighCandidates(right_list[i],
                                                      batch_size, cand_num,
                                                      dim)
    # (batch * cand) * (window*cand) * dim
    right_cands = torch.cat(right_list, dim=1)

    # (batch * cand) * (cand_num*window*2) * dim
    neigh_cands = torch.cat((left_cands, right_cands), dim=1)
    return neigh_cands
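
# getNeighCandidates assumes a getExpandNeighCandidates helper that turns a
# shifted (batch*cand_num) x dim matrix into (batch*cand_num) x cand_num x
# dim, giving every candidate of a mention the full candidate set of the
# shifted mention. A sketch under that assumption (hypothetical, not the
# repo's actual definition):
def getExpandNeighCandidates(self, emb, batch_size, cand_num, dim):
    # batch x 1 x cand_num x dim -> batch x cand_num x cand_num x dim
    expanded = emb.view(batch_size, 1, cand_num, dim).expand(
        batch_size, cand_num, cand_num, dim)
    return expanded.contiguous().view(batch_size * cand_num, cand_num, dim)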
def getTokenEmbedding(self, tokens, candidate_embeddings=None):
    tokens = to_gpu(
        Variable(torch.from_numpy(tokens),
                 volatile=not self.training)).long()
    if candidate_embeddings is not None:
        # attend over tokens with the candidate entity embedding as query
        cand_entity_emb, _, _ = candidate_embeddings
        entity_emb = self.getEmbFeatures(tokens, q_emb=cand_entity_emb)
    else:
        entity_emb = self.getEmbFeatures(tokens)
    sense_emb = entity_emb
    mu_emb = entity_emb
    return entity_emb, sense_emb, mu_emb
def getGraphSample(self, e, num_mentions, entity_vocab, id2wiki_vocab,
                   only_one=False):
    ent_label_vocab = dict([(entity_vocab[id], id2wiki_vocab[id])
                            for id in entity_vocab if id in id2wiki_vocab])
    ent_label_vocab[0] = 'PAD'
    ent_label_vocab[1] = 'UNK'
    batch_size, cand_num = e.shape
    # graph, (batch * cand) * (cand_num*window*2+1)
    adj = self._adj.data
    # neighbors, (batch * cand) * (cand_num*window*2)
    e_var = to_gpu(
        Variable(torch.from_numpy(e).view(-1).unsqueeze(1),
                 requires_grad=False).float())
    neighbors = self.getNeighCandidates(e_var, self._neighbor_cand_window,
                                        num_mentions).data.squeeze()
    c_idx = -1
    docs = []
    doc_edges = []
    is_doc_end = False
    for i in range(batch_size):
        for j in range(cand_num):
            c_idx += 1
            # skip PAD and UNK candidates
            if e[i][j] in [0, 1]:
                continue
            label = ent_label_vocab[e[i][j]]
            # edges
            edges = adj[c_idx]
            nodes = neighbors[c_idx]
            tmp_len = len(edges) - 1
            for k in range(tmp_len):
                if edges[k] > 0 and nodes[k] not in [0, 1]:
                    n_label = ent_label_vocab[nodes[k]]
                    doc_edges.append([label, n_label, edges[k]])
            # a zero mention count marks the end of a document
            if num_mentions[i][j] == 0:
                is_doc_end = True
            if is_doc_end:
                is_doc_end = False
                doc_line = "Graph: \n" + "\n".join([
                    "{}<-{}->{}".format(edge[0], edge[2], edge[1])
                    for edge in doc_edges
                ]) + '\n'
                docs.append(doc_line)
                if only_one:
                    return docs
                del doc_edges[:]
    return docs
def getNeighborMentionEmbeddings(self, ment_emb, neighbor_window,
                                 num_mentions):
    batch_size, cand_num = num_mentions.shape
    _, dim = ment_emb.size()
    left_mask, right_mask = self.getNeighborMask(num_mentions, dim)
    margin_col = to_gpu(torch.zeros(cand_num, dim))
    neighbor_ment_entity_emb = self.getNeighborMentionEmbeddingsForCandidate(
        ment_emb, margin_col, cand_num, neighbor_window, left_mask,
        right_mask)
    neighbor_ment_entity_emb = Variable(neighbor_ment_entity_emb,
                                        requires_grad=False)
    neighbor_ment_sense_emb = neighbor_ment_entity_emb
    neighbor_ment_mu_emb = neighbor_ment_entity_emb
    return (neighbor_ment_entity_emb, neighbor_ment_sense_emb,
            neighbor_ment_mu_emb)
def forward(self, contexts1, base_feature, candidates, mention_tokens,
            contexts2=None, candidates_sense=None, num_mentions=None,
            length=None):
    # baseline: score candidates by the last base feature column alone
    batch_size, cand_num, _ = base_feature.shape
    base_feature = to_gpu(
        Variable(torch.from_numpy(base_feature[:, :, -1]),
                 requires_grad=False)).float()
    return base_feature.squeeze()
def buildGraph(self, cand_emb, window, num_mentions, thred=0.0):
    batch_size, cand_num = num_mentions.shape
    # (batch * cand) * (cand_num*window*2) * dim
    neigh_cands = self.getNeighCandidates(cand_emb, window, num_mentions)
    # (batch * cand) * (cand_num*window*2) * dim
    cand_emb_expand = cand_emb.unsqueeze(1).expand(batch_size * cand_num,
                                                   2 * window * cand_num,
                                                   self._dim)
    # edge weights: cosine similarity clamped to [thred, 1],
    # (batch * cand) * (cand_num*window*2)
    adj = torch.clamp(
        F.cosine_similarity(cand_emb_expand, neigh_cands, dim=2), thred, 1)
    if thred > 0.0:
        # drop edges at or below the threshold
        adj[adj <= thred] = 0.0
    # add self connection
    margin_col = to_gpu(torch.ones(batch_size * cand_num, 1))
    # size: (batch * cand) * (cand_num*window*2+1)
    adj = torch.cat((adj * self._rho, margin_col), dim=1)
    # L1-normalize each row
    adj = Variable(F.normalize(adj, p=1, dim=1), requires_grad=False)
    return adj
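
# Self-contained illustration of the graph construction above: clamped cosine
# similarities between each candidate and its neighbor candidates, neighbor
# edges scaled by rho (standing in for self._rho), a unit self-loop, and L1
# row normalization so each adjacency row sums to 1. Toy sizes, plain
# PyTorch, no model state assumed.
import torch
import torch.nn.functional as F

cand = torch.randn(6, 8)                    # (batch*cand) x dim
neigh = torch.randn(6, 4, 8)                # (batch*cand) x neighbors x dim
rho = 0.5
cand_expand = cand.unsqueeze(1).expand(6, 4, 8)
sims = torch.clamp(F.cosine_similarity(cand_expand, neigh, dim=2), 0, 1)
adj = torch.cat((sims * rho, torch.ones(6, 1)), dim=1)  # append self-loop
adj = F.normalize(adj, p=1, dim=1)          # rows now sum to 1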
def getNeighborMentionEmbeddings(self, mention_embeddings, neighbor_window,
                                 num_mentions):
    batch_size, cand_num = num_mentions.shape
    entity_emb, sense_emb, mu_emb = mention_embeddings
    _, dim = entity_emb.size()
    left_mask, right_mask = self.getNeighborMask(num_mentions, dim)
    margin_col = to_gpu(
        Variable(torch.zeros(cand_num, dim), requires_grad=False))
    neighbor_ment_entity_emb = self.getNeighborMentionEmbeddingsForCandidate(
        entity_emb, margin_col, cand_num, neighbor_window, left_mask,
        right_mask)
    neighbor_ment_sense_emb = None
    neighbor_ment_mu_emb = None
    if sense_emb is not None:
        neighbor_ment_sense_emb = \
            self.getNeighborMentionEmbeddingsForCandidate(
                sense_emb, margin_col, cand_num, neighbor_window, left_mask,
                right_mask)
    if mu_emb is not None:
        neighbor_ment_mu_emb = \
            self.getNeighborMentionEmbeddingsForCandidate(
                mu_emb, margin_col, cand_num, neighbor_window, left_mask,
                right_mask)
    return (neighbor_ment_entity_emb, neighbor_ment_sense_emb,
            neighbor_ment_mu_emb)
def forward(self, contexts1, base_feature, candidates, m_strs,
            contexts2=None, candidates_sense=None, num_mentions=None,
            length=None):
    batch_size, cand_num, _ = base_feature.shape
    # to gpu
    base_feature = to_gpu(Variable(torch.from_numpy(base_feature))).float()
    contexts1 = to_gpu(Variable(torch.from_numpy(contexts1))).long()
    candidates = to_gpu(Variable(torch.from_numpy(candidates))).long()
    m_strs = to_gpu(Variable(torch.from_numpy(m_strs))).long()

    # candidate mask
    length_mask = None
    if length is not None:
        lengths_var = to_gpu(
            Variable(torch.from_numpy(length), requires_grad=False)).long()
        # batch_size * cand_num
        length_mask = sequence_mask(lengths_var, cand_num).float()

    # mention context mask
    has_neighbors = False
    if self._neighbor_window > 0 and num_mentions is not None:
        # batch * cand
        margin_col = to_gpu(
            Variable(torch.zeros(1, cand_num), requires_grad=False))
        right_neigh_mask = to_gpu(
            Variable(torch.from_numpy(num_mentions),
                     requires_grad=False)).float()
        left_neigh_mask = torch.cat([margin_col, right_neigh_mask[:-1, :]],
                                    dim=0)
        right_neigh_mask_expand = right_neigh_mask.view(-1).unsqueeze(
            1).expand(batch_size * cand_num, self._dim)
        left_neigh_mask_expand = left_neigh_mask.view(-1).unsqueeze(
            1).expand(batch_size * cand_num, self._dim)
        has_neighbors = True

    has_context2 = False
    if contexts2 is not None and self._use_contexts2:
        contexts2 = to_gpu(Variable(torch.from_numpy(contexts2))).long()
        has_context2 = True

    has_sense = False
    if candidates_sense is not None and self._has_sense:
        candidates_sense = to_gpu(
            Variable(torch.from_numpy(candidates_sense))).long()
        has_sense = True

    # get emb, (batch * cand) * dim
    cand_entity_emb = self.run_embed(candidates, 1)
    f1_entity_emb = self.getEmbFeatures(contexts1, q_emb=cand_entity_emb)
    if has_sense:
        cand_sense_emb = self.run_embed(candidates_sense, 2)
        cand_mu_emb = self.run_embed(candidates_sense, 3)
        f1_sense_emb = self.getEmbFeatures(contexts1, q_emb=cand_sense_emb)
        f1_mu_emb = self.getEmbFeatures(contexts1, q_emb=cand_mu_emb)
    if has_context2:
        f2_entity_emb = self.getEmbFeatures(contexts2, q_emb=cand_entity_emb)
        if has_sense:
            f2_sense_emb = self.getEmbFeatures(contexts2,
                                               q_emb=cand_sense_emb)
            f2_mu_emb = self.getEmbFeatures(contexts2, q_emb=cand_mu_emb)

    # get contextual similarity, (batch * cand) * contextual_sim
    cand_entity_emb_expand = cand_entity_emb.unsqueeze(1)
    if has_sense:
        cand_sense_emb_expand = cand_sense_emb.unsqueeze(1)
        cand_mu_emb_expand = cand_mu_emb.unsqueeze(1)

    # get mention string similarity
    ms_entity_emb = self.getEmbFeatures(m_strs, q_emb=cand_entity_emb)
    if has_sense:
        ms_sense_emb = self.getEmbFeatures(m_strs, q_emb=cand_sense_emb)
        ms_mu_emb = self.getEmbFeatures(m_strs, q_emb=cand_mu_emb)
    m_sim1 = torch.bmm(cand_entity_emb_expand,
                       ms_entity_emb.unsqueeze(2)).squeeze(2)
    if has_sense:
        m_sim2 = torch.bmm(cand_sense_emb_expand,
                           ms_sense_emb.unsqueeze(2)).squeeze(2)
        m_sim3 = torch.bmm(cand_mu_emb_expand,
                           ms_mu_emb.unsqueeze(2)).squeeze(2)

    if has_neighbors:
        # (batch * cand_num) * dim
        neigh_entity_emb = self.getNeighEmb(ms_entity_emb, cand_num,
                                            self._neighbor_window,
                                            left_neigh_mask_expand,
                                            right_neigh_mask_expand)
        n_sim1 = torch.bmm(cand_entity_emb_expand,
                           neigh_entity_emb.unsqueeze(2)).squeeze(2)
        if has_sense:
            neigh_sense_emb = self.getNeighEmb(ms_sense_emb, cand_num,
                                               self._neighbor_window,
                                               left_neigh_mask_expand,
                                               right_neigh_mask_expand)
            n_sim2 = torch.bmm(cand_sense_emb_expand,
                               neigh_sense_emb.unsqueeze(2)).squeeze(2)
            neigh_mu_emb = self.getNeighEmb(ms_mu_emb, cand_num,
                                            self._neighbor_window,
                                            left_neigh_mask_expand,
                                            right_neigh_mask_expand)
            n_sim3 = torch.bmm(cand_mu_emb_expand,
                               neigh_mu_emb.unsqueeze(2)).squeeze(2)

    # entity : context1
    sim1 = torch.bmm(cand_entity_emb_expand,
                     f1_entity_emb.unsqueeze(2)).squeeze(2)
    if has_sense:
        # sense : context1
        sim2 = torch.bmm(cand_sense_emb_expand,
                         f1_sense_emb.unsqueeze(2)).squeeze(2)
        # mu : context1
        sim3 = torch.bmm(cand_mu_emb_expand,
                         f1_mu_emb.unsqueeze(2)).squeeze(2)
    # entity : context2
    if has_context2:
        sim4 = torch.bmm(cand_entity_emb_expand,
                         f2_entity_emb.unsqueeze(2)).squeeze(2)
        if has_sense:
            # sense : context2
            sim5 = torch.bmm(cand_sense_emb_expand,
                             f2_sense_emb.unsqueeze(2)).squeeze(2)
            # mu : context2
            sim6 = torch.bmm(cand_mu_emb_expand,
                             f2_mu_emb.unsqueeze(2)).squeeze(2)

    # feature vec : batch * cand * feature_dim
    # feature dim: base_dim + 2*dim + 2 + 1 (if has entity) +
    # (2+word_dim) (if has contexts) + 1 (if has context2 and has entity)
    base_feature = base_feature.view(batch_size * cand_num, -1)
    h = torch.cat(
        (base_feature, cand_entity_emb, f1_entity_emb, sim1, m_sim1), dim=1)
    if has_sense:
        h = torch.cat((h, sim2, sim3, m_sim2, m_sim3), dim=1)
    if has_context2:
        h = torch.cat((h, sim4, f2_entity_emb), dim=1)
        if has_sense:
            h = torch.cat((h, sim5, sim6), dim=1)
    if has_neighbors:
        h = torch.cat((h, n_sim1), dim=1)
        if has_sense:
            h = torch.cat((h, n_sim2, n_sim3), dim=1)

    h = self.mlp_classifier(
        h, length=length_mask.view(-1) if length_mask is not None else None)
    # reshape, batch_size * cand_num
    h = h.view(batch_size, -1)
    output = masked_softmax(h, mask=length_mask)
    return output
def forward(self, contexts1, base_feature, candidates, mention_tokens,
            contexts2=None, candidates_sense=None, num_mentions=None,
            length=None):
    batch_size, cand_num, _ = base_feature.shape
    features = []
    # to gpu
    base_feature = to_gpu(
        Variable(torch.from_numpy(base_feature),
                 requires_grad=False)).float()
    base_feature = base_feature.view(batch_size * cand_num, -1)
    features.append(base_feature)

    # candidate mask
    length_mask = None
    if length is not None:
        lengths_var = to_gpu(
            Variable(torch.from_numpy(length), requires_grad=False)).long()
        # batch_size * cand_num
        length_mask = sequence_mask(lengths_var, cand_num).float()

    # get emb, (batch * cand) * dim
    candidate_embeddings = self.getCandidateEmbedding(candidates,
                                                      candidates_sense)
    cand_emb1, cand_emb2, cand_emb3 = candidate_embeddings

    # get context emb
    context1_emb = self.getTokenEmbedding(
        contexts1,
        candidate_embeddings=candidate_embeddings if self._use_att else None)
    # get contextual similarity, (batch * cand) * contextual_sim
    con1_sims = self.getCandidateSimilarity(context1_emb,
                                            candidate_embeddings)
    features.extend(con1_sims)

    con2_emb_cand1 = None
    if self._use_contexts2 and contexts2 is not None:
        context2_emb = self.getTokenEmbedding(
            contexts2,
            candidate_embeddings=candidate_embeddings
            if self._use_att else None)
        con2_emb_cand1, con2_emb_cand2, con2_emb_cand3 = context2_emb
        # get contextual similarity, (batch * cand) * contextual_sim
        con2_sims = self.getCandidateSimilarity(context2_emb,
                                                candidate_embeddings)
        features.extend(con2_sims)

    # get mention string similarity, todo: no att
    ment_embs = self.getTokenEmbedding(
        mention_tokens, candidate_embeddings=candidate_embeddings)
    mention_sims = self.getCandidateSimilarity(ment_embs,
                                               candidate_embeddings)
    features.extend(mention_sims)

    # neighbor mention string similarity
    if self._neighbor_window > 0 and num_mentions is not None:
        # (batch * cand_num) * dim
        neigh_ment_embs = self.getNeighborMentionEmbeddings(
            ment_embs, self._neighbor_window, num_mentions)
        neigh_ment_sims = self.getCandidateSimilarity(neigh_ment_embs,
                                                      candidate_embeddings)
        features.extend(neigh_ment_sims)

    # neighbor candidates
    # (batch * cand) * 1 * (cand_num*window*2+1)
    self._adj = self.buildGraph(cand_emb1, self._neighbor_window,
                                num_mentions,
                                thred=self._thred).unsqueeze(1)

    # feature vec : (batch * cand) * feature_dim
    h = torch.cat(features, dim=1)
    if self._use_embedding_feature:
        con1_emb_cand1, con1_emb_cand2, con1_emb_cand3 = context1_emb
        h = torch.cat((h, cand_emb1, con1_emb_cand1), dim=1)
        if con2_emb_cand1 is not None:
            h = torch.cat((h, con2_emb_cand1), dim=1)

    if self._gc_ln:
        h = self.ln_inp(h)
    for i in range(self._num_layers):
        w = getattr(self, 'w{}'.format(i))
        b = getattr(self, 'b{}'.format(i))
        dim = getattr(self, 'd{}'.format(i))
        h = h.matmul(w)
        # (batch_size * cand_num) * (2*window*cand_num+1) * f_dim
        h = self.getExpandFeature(h, self._neighbor_window, num_mentions)
        # aggregate each candidate's features over the graph
        h = torch.bmm(self._adj, h).squeeze(1)
        if b is not None:
            h = h + b
        # h: (batch_size * cand_num) * feature_dim
        if length_mask is not None:
            mask = length_mask.view(-1).unsqueeze(1).expand(
                batch_size * cand_num, dim)
            h = h * mask
        h = F.relu(h)
        h = F.dropout(h, self._dropout_rate, training=self.training)

    h = h.matmul(self.gc_classifier_w)
    # (batch_size * cand_num) * (2*window*cand_num+1) * f_dim
    h = self.getExpandFeature(h, self._neighbor_window, num_mentions)
    h = torch.bmm(self._adj, h).squeeze(1)
    if self.gc_classifier_b is not None:
        h = h + self.gc_classifier_b
    # reshape, batch_size * cand_num
    h = h.squeeze().view(batch_size, -1)
    output = masked_softmax(h, mask=length_mask)
    return output
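
# The GCN layers above assume a getExpandFeature helper that, for each
# candidate, stacks the feature rows of all neighbor candidates together
# with its own row, matching the adjacency row layout produced by buildGraph
# (neighbor edges first, self-loop last). A sketch under those assumptions
# (hypothetical, built on getNeighCandidates):
def getExpandFeature(self, h, window, num_mentions):
    # h: (batch*cand_num) x f_dim
    # neighbors: (batch*cand_num) x (cand_num*window*2) x f_dim
    neighbors = self.getNeighCandidates(h, window, num_mentions)
    # append the candidate's own features as the final (self-loop) row
    return torch.cat((neighbors, h.unsqueeze(1)), dim=1)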
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
               logger, vocabulary):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.NcelEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
        if (trainer.step -
                trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        # set model in training mode
        model.train()

        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        doc_batch = next(training_data_iter)
        batch = get_batch(doc_batch,
                          FLAGS.local_context_window,
                          use_lr_context=FLAGS.use_lr_context,
                          split_by_sent=FLAGS.split_by_sent)
        base, context1, context2, m_strs, cids, cids_sense, \
            num_candidates, num_mentions, y = batch
        # check training data
        # inspectBatch(batch, vocabulary, doc_batch)
        total_candidates = num_candidates.sum()

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        # Run model. output: batch_size * cand_num
        output = model(context1,
                       base,
                       cids,
                       m_strs,
                       contexts2=context2,
                       candidates_sense=cids_sense,
                       num_mentions=num_mentions,
                       length=num_candidates)

        target = torch.from_numpy(y).long()

        # Calculate accuracy.
        total_mentions, actual_mentions, actual_correct = \
            ComputeAccuracy(output.data, target, doc_batch)

        # Calculate loss.
        loss = nn.CrossEntropyLoss()(
            output, to_gpu(Variable(target, requires_grad=False)))
        # loss = nn.MultiLabelMarginLoss()(output, to_gpu(Variable(target, volatile=False)))

        # Backward pass.
        loss.backward()

        # Hard Gradient Clipping: clip everything except the embedding tables.
        nn.utils.clip_grad_norm([
            param for name, param in model.named_parameters() if name not in
            [
                "word_embed.embed.weight", "entity_embed.embed.weight",
                "sense_embed.embed.weight", "mu_embed.embed.weight"
            ]
        ], FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()
        end = time.time()
        total_time = end - start

        doc_accs = [
            correct / float(actual_mentions[i])
            for i, correct in enumerate(actual_correct)
        ]
        A.add('mention_prec',
              sum(actual_correct) / float(sum(actual_mentions)))
        A.add('doc_prec', sum(doc_accs) / float(len(doc_accs)))
        A.add('total_candidates', total_candidates)
        A.add('total_time', total_time)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            A.add('total_cost', loss.data[0])
            stats(model, trainer, A, log_entry)
            should_log = True
            progress_bar.finish()

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            # note: at most two eval sets, since the trainer records the best
            eval_metrics = []
            for index, eval_set in enumerate(eval_iterators):
                eval_metrics.append(
                    evaluate(FLAGS,
                             model,
                             eval_set,
                             log_entry,
                             logger,
                             show_sample=FLAGS.show_sample,
                             vocabulary=vocabulary,
                             eval_index=index))
            trainer.new_accuracy(eval_metrics)
            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and \
                trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(
            i=(trainer.step % FLAGS.statistics_interval_steps) + 1,
            total=FLAGS.statistics_interval_steps)
    finalStats(trainer, logger)
def forward(self, contexts1, base_feature, candidates, mention_tokens,
            contexts2=None, candidates_sense=None, num_mentions=None,
            length=None):
    batch_size, cand_num, _ = base_feature.shape
    features = []
    # to gpu
    base_feature = to_gpu(
        Variable(torch.from_numpy(base_feature),
                 requires_grad=False)).float()
    base_feature = base_feature.view(batch_size * cand_num, -1)
    features.append(base_feature)

    # candidate mask
    length_mask = None
    if length is not None:
        lengths_var = to_gpu(
            Variable(torch.from_numpy(length), requires_grad=False)).long()
        # batch_size * cand_num
        length_mask = sequence_mask(lengths_var, cand_num).float()

    # get emb, (batch * cand) * dim
    candidate_embeddings = self.getCandidateEmbedding(candidates,
                                                      candidates_sense)
    cand_emb1, cand_emb2, cand_emb3 = candidate_embeddings

    # get context emb
    context1_emb = self.getTokenEmbedding(
        contexts1,
        candidate_embeddings=candidate_embeddings if self._use_att else None)
    # get contextual similarity, (batch * cand) * contextual_sim
    con1_sims = self.getCandidateSimilarity(context1_emb,
                                            candidate_embeddings)
    features.extend(con1_sims)

    con2_emb_cand1 = None
    if self._use_contexts2 and contexts2 is not None:
        context2_emb = self.getTokenEmbedding(
            contexts2,
            candidate_embeddings=candidate_embeddings
            if self._use_att else None)
        con2_emb_cand1, con2_emb_cand2, con2_emb_cand3 = context2_emb
        # get contextual similarity, (batch * cand) * contextual_sim
        con2_sims = self.getCandidateSimilarity(context2_emb,
                                                candidate_embeddings)
        features.extend(con2_sims)

    # get mention string similarity
    # ment_embs = self.getTokenEmbedding(mention_tokens, candidate_embeddings=candidate_embeddings)
    ment_embs = self.getTokenEmbedding(mention_tokens)
    mention_sims = self.getCandidateSimilarity(ment_embs,
                                               candidate_embeddings)
    features.extend(mention_sims)

    # neighbor mention string similarity
    if self._neighbor_ment_window > 0 and num_mentions is not None:
        ment_entity_embs, _, _ = ment_embs
        # (batch * cand_num) * dim
        neigh_ment_embs = self.getNeighborMentionEmbeddings(
            ment_entity_embs.data, self._neighbor_ment_window, num_mentions)
        neigh_ment_sims = self.getCandidateSimilarity(neigh_ment_embs,
                                                      candidate_embeddings)
        features.extend(neigh_ment_sims)

    # neighbor candidates
    # (batch * cand) * (cand_num*window*2+1)
    self._adj = self.buildGraph(cand_emb1.data, self._neighbor_cand_window,
                                num_mentions,
                                thred=self._thred).unsqueeze(1)

    # feature vec : (batch * cand) * feature_dim
    f_vec = torch.cat(features, dim=1)
    if self._use_embedding_feature:
        con1_emb_cand1, con1_emb_cand2, con1_emb_cand3 = context1_emb
        f_vec = torch.cat((f_vec, cand_emb1, con1_emb_cand1), dim=1)
        if con2_emb_cand1 is not None:
            f_vec = torch.cat((f_vec, con2_emb_cand1), dim=1)

    # mlp classify
    gc_input = self.mlp_classifier(
        f_vec,
        length=length_mask.view(-1) if length_mask is not None else None
    ) * self._temperature
    if self._res_num > 0:
        # (batch_size * cand_num) * dim
        for i in range(self._res_num):
            l = getattr(self, 'l{}'.format(i))
            h = l(gc_input, self._adj, num_mentions, mask=length_mask)
            # skip connection
            sk_layer = getattr(self, 'sk{}'.format(i))
            if sk_layer is not None:
                h = h + sk_layer(gc_input)
            else:
                h = h + gc_input
            gc_input = h
    if self.classifier is not None:
        gc_input = self.classifier(gc_input)
    # reshape, batch_size * cand_num
    h = gc_input.squeeze().view(batch_size, -1)
    output = masked_softmax(h, mask=length_mask)
    return output
def forward(self, contexts1, base_feature, candidates, contexts2=None,
            candidates_entity=None, length=None):
    batch_size, cand_num, _ = base_feature.shape
    # to gpu
    base_feature = to_gpu(Variable(torch.from_numpy(base_feature))).float()
    contexts1 = to_gpu(Variable(torch.from_numpy(contexts1))).long()
    candidates = to_gpu(Variable(torch.from_numpy(candidates))).long()

    has_context2 = False
    if contexts2 is not None and self._use_contexts2:
        contexts2 = to_gpu(Variable(torch.from_numpy(contexts2))).long()
        has_context2 = True

    has_entity = False
    if candidates_entity is not None and self._use_entity:
        candidates_entity = to_gpu(
            Variable(torch.from_numpy(candidates_entity))).long()
        has_entity = True

    # get emb, (batch * cand) * dim
    cand_emb = self.run_embed(candidates, 1)
    cand_mu_emb = self.run_embed(candidates, 2)
    f1_sense_emb = self.getEmbFeatures(contexts1, q_emb=cand_emb)
    f1_mu_emb = self.getEmbFeatures(contexts1, q_emb=cand_mu_emb)
    if has_entity:
        cand_entity_emb = self.run_embed(candidates_entity, 3)
        f1_entity_emb = self.getEmbFeatures(contexts1,
                                            q_emb=cand_entity_emb)
    if has_context2:
        f2_sense_emb = self.getEmbFeatures(contexts2, q_emb=cand_emb)
        f2_mu_emb = self.getEmbFeatures(contexts2, q_emb=cand_mu_emb)
        if has_entity:
            f2_entity_emb = self.getEmbFeatures(contexts2,
                                                q_emb=cand_entity_emb)

    # get contextual similarity, (batch * cand) * contextual_sim
    cand_emb_expand = cand_emb.unsqueeze(1)
    cand_mu_emb_expand = cand_mu_emb.unsqueeze(1)
    # sense : context1
    sim1 = torch.bmm(cand_emb_expand, f1_sense_emb.unsqueeze(2)).squeeze(2)
    # mu : context1
    sim2 = torch.bmm(cand_mu_emb_expand, f1_mu_emb.unsqueeze(2)).squeeze(2)
    # entity : context1
    if has_entity:
        cand_entity_emb_expand = cand_entity_emb.unsqueeze(1)
        sim3 = torch.bmm(cand_entity_emb_expand,
                         f1_entity_emb.unsqueeze(2)).squeeze(2)
    if has_context2:
        # sense : context2
        sim4 = torch.bmm(cand_emb_expand,
                         f2_sense_emb.unsqueeze(2)).squeeze(2)
        # mu : context2
        sim5 = torch.bmm(cand_mu_emb_expand,
                         f2_mu_emb.unsqueeze(2)).squeeze(2)
        # entity : context2
        if has_entity:
            sim6 = torch.bmm(cand_entity_emb_expand,
                             f2_entity_emb.unsqueeze(2)).squeeze(2)

    # feature vec : batch * cand * feature_dim
    # feature dim: base_dim + sense_dim + word_dim + 2 + 1 (if has entity) +
    # (2+word_dim) (if has contexts) + 1 (if has context2 and has entity)
    base_feature = base_feature.view(batch_size * cand_num, -1)
    h = torch.cat((base_feature, cand_emb, f1_sense_emb, sim1, sim2), dim=1)
    if has_entity:
        h = torch.cat((h, sim3), dim=1)
    if has_context2:
        h = torch.cat((h, sim4, sim5, f2_sense_emb), dim=1)
        if has_entity:
            h = torch.cat((h, sim6), dim=1)
    h = self.mlp_classifier(h)
    # reshape, batch_size * cand_num
    h = h.view(batch_size, -1)

    length_mask = None
    if length is not None:
        lengths_var = to_gpu(
            Variable(torch.from_numpy(length), requires_grad=False)).long()
        # batch_size * cand_num
        length_mask = sequence_mask(lengths_var, cand_num).float()
    output = masked_softmax(h, mask=length_mask)
    return output