def forward(self, entity_reprs, pairs):
    """Build one representation per entity pair by concatenating the
    two member entities' vectors."""
    n_docs = pairs.shape[0]
    # gather the two entity vectors of every pair ...
    gathered = util.batch_index(entity_reprs, pairs)
    # ... and flatten them into a single concatenated vector per pair
    entity_pair_repr = gathered.view(n_docs, gathered.shape[1], -1)
    return entity_pair_repr
def forward(self, encodings: torch.Tensor, context_masks: torch.Tensor,
            mention_masks: torch.Tensor, entities: torch.Tensor,
            entity_masks: torch.Tensor, rel_entity_pairs: torch.Tensor,
            rel_sample_masks: torch.Tensor, rel_entity_pair_mp: torch.Tensor,
            rel_mention_pair_ep: torch.Tensor, rel_mention_pairs: torch.Tensor,
            rel_ctx_masks: torch.Tensor, rel_pair_masks: torch.Tensor,
            rel_token_distances: torch.Tensor, rel_sentence_distances: torch.Tensor,
            entity_types: torch.Tensor, max_spans: int = None,
            max_rel_pairs: int = None, inference: bool = False,
            *args, **kwargs):
    """Encode the document and score relations between candidate entity pairs.

    Pipeline: BERT-encode tokens -> pool mention representations -> pool
    entity representations -> build entity-pair representations -> classify
    relations. During inference, logits are converted to probabilities,
    thresholded, and padding pairs are masked out.

    Returns a dict with key 'rel_clf': relation scores per entity pair
    (sigmoid probabilities when `inference` is True, raw logits otherwise).
    """
    # masks arrive as integer tensors; downstream pooling expects floats
    context_masks = context_masks.float()
    mention_masks = mention_masks.float()
    entity_masks = entity_masks.float()

    # contextualized token embeddings from the transformer encoder
    h = self.bert(input_ids=encodings, attention_mask=context_masks)['last_hidden_state']

    # pool token embeddings into mention spans, then mentions into entities
    mention_reprs = self.mention_representation(h, mention_masks, max_spans=max_spans)
    entity_reprs = self.entity_representation(mention_reprs, entities, entity_masks)

    entity_pair_reprs = self.entity_pair_representation(entity_reprs, rel_entity_pairs)

    # look up the entity type of both members of every candidate pair
    rel_entity_types = util.batch_index(entity_types, rel_entity_pairs)

    rel_clf = self.relation_classification(entity_pair_reprs, h, mention_reprs,
                                           rel_entity_pair_mp, rel_mention_pair_ep,
                                           rel_mention_pairs, rel_ctx_masks,
                                           rel_pair_masks, rel_token_distances,
                                           rel_sentence_distances, rel_entity_types,
                                           max_pairs=max_rel_pairs)

    if inference:
        # convert logits to probabilities, zero scores below the relation
        # threshold and mask out padding pairs
        # NOTE(review): source was single-line; assuming all three statements
        # belong to the inference branch (training losses use raw logits) — confirm
        rel_clf = torch.sigmoid(rel_clf)
        rel_clf[rel_clf < self._rel_threshold] = 0
        rel_clf *= rel_sample_masks.unsqueeze(-1)

    return dict(rel_clf=rel_clf)
def _forward_train(self, encodings: torch.Tensor, context_masks: torch.Tensor,
                   mention_masks: torch.Tensor, mention_sizes: torch.Tensor,
                   entities: torch.Tensor, entity_masks: torch.Tensor,
                   coref_mention_pairs: torch.Tensor, rel_entity_pairs: torch.Tensor,
                   rel_mention_pairs: torch.Tensor, rel_ctx_masks: torch.Tensor,
                   rel_entity_pair_mp: torch.Tensor, rel_mention_pair_ep: torch.Tensor,
                   rel_pair_masks: torch.Tensor, rel_token_distances: torch.Tensor,
                   rel_sentence_distances: torch.Tensor, entity_types: torch.Tensor,
                   coref_eds: torch.Tensor, max_spans: int = None,
                   max_coref_pairs: int = None, max_rel_pairs: int = None,
                   *args, **kwargs):
    """Training forward pass over all four sub-tasks.

    Delegates mention, entity and coreference scoring to
    `_forward_train_common`, then builds entity-pair representations and
    scores relations using the gold entity types.

    Returns a dict of raw logits: 'mention_clf', 'entity_clf', 'coref_clf'
    and 'rel_clf'.
    """
    # shared encoding + mention/entity/coref classification
    res = self._forward_train_common(encodings, context_masks, mention_masks,
                                     mention_sizes, entities, entity_masks,
                                     coref_mention_pairs, coref_eds,
                                     max_coref_pairs=max_coref_pairs,
                                     max_spans=max_spans)
    h, mention_reprs, entity_reprs, mention_clf, entity_clf, coref_clf = res

    entity_pair_reprs = self.entity_pair_representation(
        entity_reprs, rel_entity_pairs)

    # gold entity types of both members of every candidate pair
    rel_entity_types = util.batch_index(entity_types, rel_entity_pairs)

    rel_clf = self.relation_classification(entity_pair_reprs, h, mention_reprs,
                                           rel_entity_pair_mp, rel_mention_pair_ep,
                                           rel_mention_pairs, rel_ctx_masks,
                                           rel_pair_masks, rel_token_distances,
                                           rel_sentence_distances, rel_entity_types,
                                           max_pairs=max_rel_pairs)

    return dict(mention_clf=mention_clf, entity_clf=entity_clf,
                coref_clf=coref_clf, rel_clf=rel_clf)
def forward(self, mention_reprs, entities, entity_masks):
    """Pool each entity's representation from its mention representations
    via masked max-pooling, followed by dropout."""
    # collect the mention vectors belonging to each entity cluster
    clusters = util.batch_index(mention_reprs, entities)
    # push padded mention slots to a large negative value so that
    # max-pooling over the cluster dimension ignores them
    neg_fill = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
    pooled = (clusters + neg_fill).max(dim=2)[0]
    return self.dropout(pooled)
def forward(self, entity_reprs, rel_entity_types, pairs):
    """Create entity-pair representations enriched with entity-type
    embeddings: concat both entity vectors and both type embeddings,
    project, ReLU, dropout."""
    type_embs = self.entity_embeddings(rel_entity_types)
    # gather the two entity vectors of each pair and flatten to one vector
    gathered = util.batch_index(entity_reprs, pairs)
    gathered = gathered.view(pairs.shape[0], gathered.shape[1], -1)
    # flatten the two type embeddings of each pair the same way
    type_embs = type_embs.view(type_embs.shape[0], type_embs.shape[1], -1)
    # project the concatenation of pair vectors and type embeddings
    projected = self.entity_pair_linear(torch.cat([gathered, type_embs], dim=2))
    return self.dropout(torch.relu(projected))
def _create_mention_pair_representations(self, entity_pair_reprs,
                                         chunk_rel_mention_pair_ep,
                                         rel_mention_pairs, rel_ctx_masks,
                                         rel_token_distances,
                                         rel_sentence_distances,
                                         mention_reprs, h):
    """Build local (mention-level) pair representations for relation
    classification.

    Concatenates, per mention pair: a max-pooled context vector, the two
    mention vectors, the associated global entity-pair representation, and
    embeddings of the token/sentence distances between the mentions; the
    result is projected by `pair_linear` and regularized with dropout.
    """
    # embed token/sentence distances between the two mentions of each pair
    rel_token_distances = self.token_distance_embeddings(rel_token_distances)
    rel_sentence_distances = self.sentence_distance_embeddings(rel_sentence_distances)

    # gather and concatenate the two mention vectors of every pair
    rel_mention_pair_reprs = util.batch_index(mention_reprs, rel_mention_pairs)
    s = rel_mention_pair_reprs.shape
    rel_mention_pair_reprs = rel_mention_pair_reprs.view(s[0], s[1], -1)

    # ctx max pooling: tokens outside the context window are pushed to a
    # large negative value so max() ignores them
    m = ((rel_ctx_masks == 0).float() * (-1e30)).unsqueeze(-1)
    # fix: take only the pooled values; the original also unpacked the
    # max-indices into an unused local (rel_ctx_indices)
    rel_ctx = (m + h).max(dim=2)[0]
    # set the context vector of neighboring or adjacent spans (empty
    # context window) to zero instead of keeping the -1e30 fill values
    rel_ctx[rel_ctx_masks.bool().any(-1) == 0] = 0

    # global entity-pair representation associated with each mention pair
    entity_pair_reprs = util.batch_index(entity_pair_reprs,
                                         chunk_rel_mention_pair_ep)

    local_repr = torch.cat([
        rel_ctx, rel_mention_pair_reprs, entity_pair_reprs,
        rel_token_distances, rel_sentence_distances
    ], dim=2)
    local_repr = self.dropout(self.pair_linear(local_repr))
    return local_repr
def _forward_inference(self, encodings: torch.Tensor, context_masks: torch.Tensor,
                       mention_masks: torch.Tensor, mention_sizes: torch.Tensor,
                       mention_spans: torch.Tensor, mention_sample_masks: torch.Tensor,
                       max_spans: int = None, max_coref_pairs: int = None,
                       *args, **kwargs):
    """Inference forward pass (global entity-pair variant).

    Runs the shared encoding/mention/entity/coref pipeline, builds all
    global entity pairs, classifies relations using the *predicted* entity
    types, and finally applies thresholds and sample masks.

    Returns a dict with thresholded scores for all four sub-tasks plus the
    predicted clusters, their sample masks and the entity pairs.
    """
    res = self._forward_inference_common(encodings, context_masks,
                                         mention_masks, mention_sizes,
                                         mention_spans, mention_sample_masks,
                                         max_coref_pairs=max_coref_pairs,
                                         max_spans=max_spans)
    (h, mention_reprs, entity_reprs, clusters, entity_sample_masks,
     mention_pair_sample_masks, clusters_sample_masks,
     mention_clf, entity_clf, coref_clf) = res

    # create entity pairs
    rel_entity_pairs, rel_sample_masks = misc.create_rel_global_entity_pairs(
        entity_reprs, entity_sample_masks)
    rel_sample_masks = rel_sample_masks.float()

    # create entity pair representations using predicted entity types
    entity_types = entity_clf.argmax(dim=-1)
    rel_entity_types = util.batch_index(entity_types, rel_entity_pairs)
    entity_pair_reprs = self.entity_pair_representation(
        entity_reprs, rel_entity_types, rel_entity_pairs)

    # classify relations
    rel_clf = self.relation_classification(entity_pair_reprs)

    # thresholding and masking
    mention_clf, coref_clf, entity_clf, rel_clf = self._apply_thresholds(
        mention_clf, coref_clf, entity_clf, rel_clf,
        mention_sample_masks, mention_pair_sample_masks,
        entity_sample_masks, rel_sample_masks)

    return dict(mention_clf=mention_clf, coref_clf=coref_clf,
                entity_clf=entity_clf, rel_clf=rel_clf, clusters=clusters,
                clusters_sample_masks=clusters_sample_masks,
                rel_entity_pairs=rel_entity_pairs)
def _classify_corefs(self, mention_reprs, coref_mention_pairs, coref_eds):
    """Score candidate coreferent mention pairs; returns one logit per
    candidate pair."""
    n_docs = coref_mention_pairs.shape[0]
    # gather both mention vectors of each candidate pair and flatten
    # them into one concatenated vector per pair
    paired = util.batch_index(mention_reprs, coref_mention_pairs)
    paired = paired.view(n_docs, paired.shape[1], -1)
    # append the additional pair features (coref_eds) and project
    features = torch.relu(self.coref_linear(torch.cat([paired, coref_eds], dim=2)))
    features = self.dropout(features)
    # classify coref candidates and drop the trailing singleton dimension
    return self.coref_classifier(features).squeeze(dim=-1)
def _classify_relations(self, rel_mention_pair_reprs, rel_entity_pair_mp,
                        rel_pair_masks, rel_entity_types):
    """Aggregate mention-pair representations per entity pair via masked
    max-pooling and score relation candidates; returns one logit vector
    per entity pair."""
    # collect the mention-pair vectors belonging to each entity pair
    pooled = util.batch_index(rel_mention_pair_reprs, rel_entity_pair_mp)
    # push padded mention-pair slots to -1e30, then max-pool them away
    pooled += (rel_pair_masks.unsqueeze(-1) == 0).float() * (-1e30)
    pooled = pooled.max(dim=2)[0]
    # flatten the two entity-type embeddings of each pair
    type_embs = self.entity_type_embeddings(rel_entity_types)
    type_embs = type_embs.view(type_embs.shape[0], type_embs.shape[1], -1)
    # project the pooled representation joined with the type embeddings
    combined = torch.cat([pooled, type_embs], dim=2)
    combined = self.dropout(torch.relu(self.rel_linear(combined)))
    # classify relation candidates
    return self.rel_classifier(combined).squeeze(dim=-1)
def _forward_inference(self, encodings: torch.Tensor, context_masks: torch.Tensor,
                       mention_masks: torch.Tensor, mention_sizes: torch.Tensor,
                       mention_spans: torch.Tensor, mention_sample_masks: torch.Tensor,
                       mention_sent_indices: torch.Tensor,
                       mention_orig_spans: torch.Tensor,
                       max_spans: int = None, max_coref_pairs: int = None,
                       max_rel_pairs: int = None, *args, **kwargs):
    """Inference forward pass (local mention-pair variant).

    Runs the shared encoding/mention/entity/coref pipeline, builds local
    entity pairs together with their mention pairs and context/distance
    features, classifies relations using the *predicted* entity types, and
    applies thresholds and sample masks.

    Returns a dict with thresholded scores for all four sub-tasks plus the
    predicted clusters, their sample masks and the entity pairs.
    """
    res = self._forward_inference_common(encodings, context_masks,
                                         mention_masks, mention_sizes,
                                         mention_spans, mention_sample_masks,
                                         max_spans=max_spans,
                                         max_coref_pairs=max_coref_pairs)
    (h, mention_reprs, entity_reprs, clusters, entity_sample_masks,
     mention_pair_sample_masks, clusters_sample_masks,
     mention_clf, entity_clf, coref_clf) = res

    # create entity pairs together with their mention pairs, context masks
    # and token/sentence distance features
    (rel_entity_pair_mp, rel_mention_pair_ep, rel_entity_pairs,
     rel_mention_pairs, rel_ctx_masks, rel_token_distances,
     rel_sentence_distances, rel_mention_pair_masks) = misc.create_local_entity_pairs(
        clusters, clusters_sample_masks, mention_spans, mention_sent_indices,
        mention_orig_spans, context_masks.shape[-1])
    # an entity pair is a valid sample iff it has at least one mention pair
    rel_sample_masks = rel_mention_pair_masks.any(dim=-1)

    # create entity pair representations
    entity_pair_reprs = self.entity_pair_representation(
        entity_reprs, rel_entity_pairs)

    # classify relations using the predicted entity types
    entity_types = entity_clf.argmax(dim=-1)
    rel_entity_types = util.batch_index(entity_types, rel_entity_pairs)
    rel_clf = self.relation_classification(entity_pair_reprs, h, mention_reprs,
                                           rel_entity_pair_mp, rel_mention_pair_ep,
                                           rel_mention_pairs, rel_ctx_masks,
                                           rel_mention_pair_masks,
                                           rel_token_distances,
                                           rel_sentence_distances,
                                           rel_entity_types,
                                           max_pairs=max_rel_pairs)

    # thresholding and masking
    mention_clf, coref_clf, entity_clf, rel_clf = self._apply_thresholds(
        mention_clf, coref_clf, entity_clf, rel_clf,
        mention_sample_masks, mention_pair_sample_masks,
        entity_sample_masks, rel_sample_masks)

    return dict(mention_clf=mention_clf, coref_clf=coref_clf,
                entity_clf=entity_clf, rel_clf=rel_clf, clusters=clusters,
                clusters_sample_masks=clusters_sample_masks,
                rel_entity_pairs=rel_entity_pairs)