def forward(self, entity_reprs, pairs):
    """Gather the representations referenced by ``pairs`` and flatten each pair.

    Args:
        entity_reprs: per-batch entity representations, indexed per batch
            item by ``pairs`` via ``util.batch_index``.
        pairs: per-batch index tensor; its first dimension is the batch size.
            Presumably ``(batch, num_pairs, 2)`` head/tail indices — TODO confirm.

    Returns:
        Tensor viewed as ``(batch, num_pairs, -1)``: all trailing dimensions of
        each gathered pair are merged into a single vector (for a pair index of
        width 2 this is the two entity vectors laid end to end).
    """
    n_batch = pairs.shape[0]

    gathered = util.batch_index(entity_reprs, pairs)
    n_pairs = gathered.shape[1]

    return gathered.view(n_batch, n_pairs, -1)
Example #2
0
    def forward(self, encodings: torch.Tensor, context_masks: torch.Tensor, mention_masks: torch.Tensor,
                entities: torch.Tensor, entity_masks: torch.Tensor,
                rel_entity_pairs: torch.Tensor, rel_sample_masks: torch.Tensor,
                rel_entity_pair_mp: torch.Tensor, rel_mention_pair_ep: torch.Tensor,
                rel_mention_pairs: torch.Tensor, rel_ctx_masks: torch.Tensor, rel_pair_masks: torch.Tensor,
                rel_token_distances: torch.Tensor, rel_sentence_distances: torch.Tensor, entity_types: torch.Tensor,
                max_spans: bool = None, max_rel_pairs: bool = None, inference: bool = False, *args, **kwargs):
        """Encode a batch and classify relations between the given entity pairs.

        Pipeline: BERT encoding -> mention representations -> entity
        representations -> entity-pair representations -> relation
        classification (which also receives the mention-pair level tensors and
        the per-pair entity types gathered from ``entity_types``).

        Args:
            encodings: token ids fed to ``self.bert`` as ``input_ids``.
            context_masks: attention mask for ``self.bert`` (cast to float).
            mention_masks: mask consumed by ``mention_representation`` (cast to float).
            entities: indices grouping mentions into entities.
            entity_masks: mask over ``entities`` (cast to float).
            rel_entity_pairs: entity index pairs to classify; also used to
                gather each pair's entity types from ``entity_types``.
            rel_sample_masks: multiplied into the scores in inference mode,
                zeroing scores where the mask is 0.
            rel_entity_pair_mp, rel_mention_pair_ep, rel_mention_pairs,
            rel_ctx_masks, rel_pair_masks, rel_token_distances,
            rel_sentence_distances: passed through unchanged to
                ``relation_classification``.
            entity_types: per-entity type ids, gathered per pair.
            max_spans: forwarded as ``max_spans=`` to ``mention_representation``.
            max_rel_pairs: forwarded as ``max_pairs=`` to ``relation_classification``.
                NOTE(review): both ``max_*`` arguments are annotated ``bool``
                but look like size limits — confirm the intended type.
            inference: if True, apply sigmoid, threshold at
                ``self._rel_threshold`` and mask with ``rel_sample_masks``.

        Returns:
            ``dict(rel_clf=...)``. The value is the raw output of
            ``relation_classification`` (presumably logits) when ``inference``
            is False, and thresholded, masked sigmoid scores otherwise.
        """
        context_masks = context_masks.float()
        mention_masks = mention_masks.float()
        entity_masks = entity_masks.float()

        h = self.bert(input_ids=encodings, attention_mask=context_masks)['last_hidden_state']
        mention_reprs = self.mention_representation(h, mention_masks, max_spans=max_spans)
        entity_reprs = self.entity_representation(mention_reprs, entities, entity_masks)
        entity_pair_reprs = self.entity_pair_representation(entity_reprs, rel_entity_pairs)

        # entity type ids of both members of every candidate pair
        rel_entity_types = util.batch_index(entity_types, rel_entity_pairs)
        rel_clf = self.relation_classification(entity_pair_reprs, h, mention_reprs,
                                               rel_entity_pair_mp, rel_mention_pair_ep,
                                               rel_mention_pairs, rel_ctx_masks, rel_pair_masks,
                                               rel_token_distances, rel_sentence_distances, rel_entity_types,
                                               max_pairs=max_rel_pairs)

        if inference:
            rel_clf = torch.sigmoid(rel_clf)
            # drop low-confidence predictions, then zero out padded/invalid pairs
            rel_clf[rel_clf < self._rel_threshold] = 0
            rel_clf *= rel_sample_masks.unsqueeze(-1)

        return dict(rel_clf=rel_clf)
Example #3
0
    def _forward_train(self,
                       encodings: torch.Tensor,
                       context_masks: torch.Tensor,
                       mention_masks: torch.Tensor,
                       mention_sizes: torch.Tensor,
                       entities: torch.Tensor,
                       entity_masks: torch.Tensor,
                       coref_mention_pairs: torch.Tensor,
                       rel_entity_pairs: torch.Tensor,
                       rel_mention_pairs: torch.Tensor,
                       rel_ctx_masks: torch.Tensor,
                       rel_entity_pair_mp: torch.Tensor,
                       rel_mention_pair_ep: torch.Tensor,
                       rel_pair_masks: torch.Tensor,
                       rel_token_distances: torch.Tensor,
                       rel_sentence_distances: torch.Tensor,
                       entity_types: torch.Tensor,
                       coref_eds: torch.Tensor,
                       max_spans: bool = None,
                       max_coref_pairs: bool = None,
                       max_rel_pairs: bool = None,
                       *args,
                       **kwargs):
        """Training forward pass producing mention, entity, coref and relation outputs.

        The shared part (encoding, mention/entity representations and the
        mention, entity and coreference classifiers) is delegated to
        ``_forward_train_common``. This method then builds entity-pair
        representations for ``rel_entity_pairs``, gathers the pairs' entity
        types from ``entity_types``, and runs ``relation_classification``.

        Args:
            max_spans: forwarded as ``max_spans=`` to ``_forward_train_common``.
            max_coref_pairs: forwarded as ``max_coref_pairs=`` to ``_forward_train_common``.
            max_rel_pairs: forwarded as ``max_pairs=`` to ``relation_classification``.
                NOTE(review): the ``max_*`` arguments are annotated ``bool``
                but look like size limits — confirm the intended type.
            All other tensors are passed through unchanged to the helpers above.

        Returns:
            dict with keys ``mention_clf``, ``entity_clf``, ``coref_clf`` and
            ``rel_clf``. No sigmoid or thresholding is applied here.
        """
        res = self._forward_train_common(encodings,
                                         context_masks,
                                         mention_masks,
                                         mention_sizes,
                                         entities,
                                         entity_masks,
                                         coref_mention_pairs,
                                         coref_eds,
                                         max_coref_pairs=max_coref_pairs,
                                         max_spans=max_spans)
        h, mention_reprs, entity_reprs, mention_clf, entity_clf, coref_clf = res

        entity_pair_reprs = self.entity_pair_representation(
            entity_reprs, rel_entity_pairs)

        # entity type ids of both members of every candidate pair
        rel_entity_types = util.batch_index(entity_types, rel_entity_pairs)
        rel_clf = self.relation_classification(entity_pair_reprs,
                                               h,
                                               mention_reprs,
                                               rel_entity_pair_mp,
                                               rel_mention_pair_ep,
                                               rel_mention_pairs,
                                               rel_ctx_masks,
                                               rel_pair_masks,
                                               rel_token_distances,
                                               rel_sentence_distances,
                                               rel_entity_types,
                                               max_pairs=max_rel_pairs)

        return dict(mention_clf=mention_clf,
                    entity_clf=entity_clf,
                    coref_clf=coref_clf,
                    rel_clf=rel_clf)
Example #4
0
    def forward(self, mention_reprs, entities, entity_masks):
        """Max-pool the mention representations of each entity into one vector.

        Args:
            mention_reprs: per-batch mention representations.
            entities: per-batch indices selecting, for each entity, its
                mentions (gathered with ``util.batch_index``).
            entity_masks: mask aligned with ``entities``; entries equal to 0
                are excluded from the max pooling.

        Returns:
            Pooled entity representations (max over dim 2 of the gathered
            clusters) after ``self.dropout``.
        """
        clusters = util.batch_index(mention_reprs, entities)

        # Additive mask: masked-out mentions get -1e30 so they never win the max.
        penalty = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
        pooled, _ = (clusters + penalty).max(dim=2)

        return self.dropout(pooled)
    def forward(self, entity_reprs, rel_entity_types, pairs):
        """Build entity-pair representations enriched with entity-type embeddings.

        Each pair's gathered entity representations are flattened into one
        vector, concatenated with the flattened type embeddings of the pair,
        and passed through ``entity_pair_linear`` -> ReLU -> dropout.

        Args:
            entity_reprs: per-batch entity representations.
            rel_entity_types: type ids per pair, embedded with ``self.entity_embeddings``.
            pairs: per-batch index tensor into ``entity_reprs``; its first
                dimension is the batch size.

        Returns:
            Tensor of pair representations, concatenated along dim 2 before
            the linear layer.
        """
        type_embeds = self.entity_embeddings(rel_entity_types)

        n_batch = pairs.shape[0]
        pair_vecs = util.batch_index(entity_reprs, pairs)
        pair_vecs = pair_vecs.view(n_batch, pair_vecs.shape[1], -1)

        # merge everything after the first two dims into one type vector per pair
        type_embeds = type_embeds.view(*type_embeds.shape[:2], -1)

        fused = torch.cat([pair_vecs, type_embeds], dim=2)
        return self.dropout(torch.relu(self.entity_pair_linear(fused)))
    def _create_mention_pair_representations(self, entity_pair_reprs,
                                             chunk_rel_mention_pair_ep,
                                             rel_mention_pairs, rel_ctx_masks,
                                             rel_token_distances,
                                             rel_sentence_distances,
                                             mention_reprs, h):
        """Build a local representation for every mention pair.

        The representation concatenates (along dim 2):
          1. a max-pooled context vector over the tokens selected by ``rel_ctx_masks``,
          2. the flattened representations of the two mentions,
          3. the representation of the entity pair the mention pair belongs
             to (gathered via ``chunk_rel_mention_pair_ep``),
          4. token-distance and 5. sentence-distance embeddings,
        and is then projected with ``pair_linear`` followed by dropout.

        NOTE(review): ``h`` is added to a mask of shape ``(..., seq, 1)`` and
        max-pooled over dim 2, so the caller presumably passes ``h`` already
        expanded per mention pair — TODO confirm.

        Returns:
            The projected local mention-pair representations.
        """
        token_dist_embeds = self.token_distance_embeddings(rel_token_distances)
        sent_dist_embeds = self.sentence_distance_embeddings(rel_sentence_distances)

        mention_pair_vecs = util.batch_index(mention_reprs, rel_mention_pairs)
        dim0, dim1 = mention_pair_vecs.shape[0], mention_pair_vecs.shape[1]
        mention_pair_vecs = mention_pair_vecs.view(dim0, dim1, -1)

        # Context max pooling: tokens outside the context mask get -1e30.
        ctx_penalty = ((rel_ctx_masks == 0).float() * (-1e30)).unsqueeze(-1)
        ctx = (ctx_penalty + h).max(dim=2)[0]

        # Pairs whose context mask is entirely empty (per the original comment:
        # neighboring or adjacent spans) get an all-zero context vector.
        has_no_ctx = ~rel_ctx_masks.bool().any(-1)
        ctx[has_no_ctx] = 0

        owner_pair_vecs = util.batch_index(entity_pair_reprs,
                                           chunk_rel_mention_pair_ep)

        features = torch.cat([ctx, mention_pair_vecs, owner_pair_vecs,
                              token_dist_embeds, sent_dist_embeds], dim=2)

        return self.dropout(self.pair_linear(features))
Example #7
0
    def _forward_inference(self,
                           encodings: torch.Tensor,
                           context_masks: torch.Tensor,
                           mention_masks: torch.Tensor,
                           mention_sizes: torch.Tensor,
                           mention_spans: torch.Tensor,
                           mention_sample_masks: torch.Tensor,
                           max_spans=None,
                           max_coref_pairs=None,
                           *args,
                           **kwargs):
        """Run joint inference and classify relations over all global entity pairs.

        Steps:
          1. ``_forward_inference_common`` yields entity representations,
             clusters, the various sample masks and the mention, entity and
             coreference scores.
          2. ``misc.create_rel_global_entity_pairs`` builds the candidate
             entity pairs and their sample mask from ``entity_reprs`` and
             ``entity_sample_masks``.
          3. Entity types are taken as ``entity_clf.argmax(dim=-1)``, gathered
             per pair, and used by ``entity_pair_representation``.
          4. ``relation_classification`` scores the pairs, and
             ``_apply_thresholds`` thresholds/masks all four score tensors.

        Args:
            max_spans: forwarded as ``max_spans=`` to ``_forward_inference_common``.
            max_coref_pairs: forwarded as ``max_coref_pairs=`` to ``_forward_inference_common``.
            The tensor arguments are passed through unchanged to ``_forward_inference_common``.

        Returns:
            dict with keys ``mention_clf``, ``coref_clf``, ``entity_clf``,
            ``rel_clf``, ``clusters``, ``clusters_sample_masks`` and
            ``rel_entity_pairs``.
        """
        res = self._forward_inference_common(encodings,
                                             context_masks,
                                             mention_masks,
                                             mention_sizes,
                                             mention_spans,
                                             mention_sample_masks,
                                             max_coref_pairs=max_coref_pairs,
                                             max_spans=max_spans)
        (h, mention_reprs, entity_reprs, clusters, entity_sample_masks,
         mention_pair_sample_masks, clusters_sample_masks, mention_clf,
         entity_clf, coref_clf) = res

        # create entity pairs
        rel_entity_pairs, rel_sample_masks = misc.create_rel_global_entity_pairs(
            entity_reprs, entity_sample_masks)
        rel_sample_masks = rel_sample_masks.float()

        # create entity pair representations (using the argmax entity type)
        entity_types = entity_clf.argmax(dim=-1)
        rel_entity_types = util.batch_index(entity_types, rel_entity_pairs)
        entity_pair_reprs = self.entity_pair_representation(
            entity_reprs, rel_entity_types, rel_entity_pairs)

        # classify relations
        rel_clf = self.relation_classification(entity_pair_reprs)

        # thresholding and masking
        mention_clf, coref_clf, entity_clf, rel_clf = self._apply_thresholds(
            mention_clf, coref_clf, entity_clf, rel_clf, mention_sample_masks,
            mention_pair_sample_masks, entity_sample_masks, rel_sample_masks)

        return dict(mention_clf=mention_clf,
                    coref_clf=coref_clf,
                    entity_clf=entity_clf,
                    rel_clf=rel_clf,
                    clusters=clusters,
                    clusters_sample_masks=clusters_sample_masks,
                    rel_entity_pairs=rel_entity_pairs)
Example #8
0
    def _classify_corefs(self, mention_reprs, coref_mention_pairs, coref_eds):
        """Score candidate coreference mention pairs.

        Each pair's two mention representations are gathered and flattened,
        concatenated with ``coref_eds`` along dim 2, passed through
        ``coref_linear`` -> ReLU -> dropout, and scored by ``coref_classifier``.

        Args:
            mention_reprs: per-batch mention representations.
            coref_mention_pairs: per-batch mention index pairs; its first
                dimension is the batch size.
            coref_eds: extra per-pair features concatenated to the pair
                representation (presumably an edit-distance embedding — TODO confirm).

        Returns:
            Coreference logits with the trailing singleton dimension squeezed.
        """
        n_batch = coref_mention_pairs.shape[0]

        gathered = util.batch_index(mention_reprs, coref_mention_pairs)
        pair_vecs = gathered.view(n_batch, gathered.shape[1], -1)

        hidden = torch.relu(self.coref_linear(torch.cat([pair_vecs, coref_eds], dim=2)))
        hidden = self.dropout(hidden)

        return self.coref_classifier(hidden).squeeze(dim=-1)
    def _classify_relations(self, rel_mention_pair_reprs, rel_entity_pair_mp,
                            rel_pair_masks, rel_entity_types):
        """Score relation candidates from their mention-pair representations.

        For every entity pair, the representations of its mention pairs
        (selected with ``rel_entity_pair_mp``) are max-pooled, ignoring
        entries where ``rel_pair_masks`` is 0. The result is concatenated with
        the flattened entity-type embeddings of the pair, passed through
        ``rel_linear`` -> ReLU -> dropout, and scored by ``rel_classifier``.

        Returns:
            Relation logits with the trailing singleton dimension squeezed.
        """
        candidates = util.batch_index(rel_mention_pair_reprs,
                                      rel_entity_pair_mp)

        # push masked mention pairs to -1e30 so max pooling ignores them
        candidates += (rel_pair_masks.unsqueeze(-1) == 0).float() * (-1e30)
        pooled, _ = candidates.max(dim=2)

        type_embeds = self.entity_type_embeddings(rel_entity_types)
        type_embeds = type_embeds.view(*type_embeds.shape[:2], -1)

        rel_hidden = torch.cat([pooled, type_embeds], dim=2)
        rel_hidden = self.dropout(torch.relu(self.rel_linear(rel_hidden)))

        return self.rel_classifier(rel_hidden).squeeze(dim=-1)
Example #10
0
    def _forward_inference(self,
                           encodings: torch.Tensor,
                           context_masks: torch.Tensor,
                           mention_masks: torch.Tensor,
                           mention_sizes: torch.Tensor,
                           mention_spans: torch.Tensor,
                           mention_sample_masks: torch.Tensor,
                           mention_sent_indices: torch.Tensor,
                           mention_orig_spans: torch.Tensor,
                           max_spans: bool = None,
                           max_coref_pairs: bool = None,
                           max_rel_pairs: bool = None,
                           *args,
                           **kwargs):
        """Run joint inference and classify relations over locally built entity pairs.

        Steps:
          1. ``_forward_inference_common`` yields the encoder output ``h``,
             mention/entity representations, clusters, sample masks and the
             mention, entity and coreference scores.
          2. ``misc.create_local_entity_pairs`` builds the entity pairs and
             the mention-pair level tensors (context masks, token/sentence
             distances, mention-pair masks) from ``clusters``, the mention
             spans/sentence indices and the sequence length
             ``context_masks.shape[-1]``.
          3. An entity pair is kept in ``rel_sample_masks`` if any entry of
             its ``rel_mention_pair_masks`` row is set.
          4. Entity types are taken as ``entity_clf.argmax(dim=-1)`` and
             gathered per pair; ``relation_classification`` scores the pairs.
          5. ``_apply_thresholds`` thresholds/masks all four score tensors.

        Args:
            max_spans: forwarded as ``max_spans=`` to ``_forward_inference_common``.
            max_coref_pairs: forwarded as ``max_coref_pairs=`` to ``_forward_inference_common``.
            max_rel_pairs: forwarded as ``max_pairs=`` to ``relation_classification``.
                NOTE(review): the ``max_*`` arguments are annotated ``bool``
                but look like size limits — confirm the intended type.
            The tensor arguments are passed through unchanged to the helpers above.

        Returns:
            dict with keys ``mention_clf``, ``coref_clf``, ``entity_clf``,
            ``rel_clf``, ``clusters``, ``clusters_sample_masks`` and
            ``rel_entity_pairs``.
        """
        res = self._forward_inference_common(encodings,
                                             context_masks,
                                             mention_masks,
                                             mention_sizes,
                                             mention_spans,
                                             mention_sample_masks,
                                             max_spans=max_spans,
                                             max_coref_pairs=max_coref_pairs)
        (h, mention_reprs, entity_reprs, clusters, entity_sample_masks,
         mention_pair_sample_masks, clusters_sample_masks, mention_clf,
         entity_clf, coref_clf) = res

        # create entity pairs (plus the mention-pair level inputs of the relation classifier)
        (rel_entity_pair_mp, rel_mention_pair_ep, rel_entity_pairs,
         rel_mention_pairs, rel_ctx_masks, rel_token_distances,
         rel_sentence_distances,
         rel_mention_pair_masks) = misc.create_local_entity_pairs(
             clusters, clusters_sample_masks, mention_spans,
             mention_sent_indices, mention_orig_spans, context_masks.shape[-1])
        # an entity pair is a valid sample if at least one of its mention pairs is
        rel_sample_masks = rel_mention_pair_masks.any(dim=-1)

        # create entity pair representations
        entity_pair_reprs = self.entity_pair_representation(
            entity_reprs, rel_entity_pairs)

        # classify relations (using the argmax entity type of each pair member)
        entity_types = entity_clf.argmax(dim=-1)
        rel_entity_types = util.batch_index(entity_types, rel_entity_pairs)
        rel_clf = self.relation_classification(entity_pair_reprs,
                                               h,
                                               mention_reprs,
                                               rel_entity_pair_mp,
                                               rel_mention_pair_ep,
                                               rel_mention_pairs,
                                               rel_ctx_masks,
                                               rel_mention_pair_masks,
                                               rel_token_distances,
                                               rel_sentence_distances,
                                               rel_entity_types,
                                               max_pairs=max_rel_pairs)

        # thresholding and masking
        mention_clf, coref_clf, entity_clf, rel_clf = self._apply_thresholds(
            mention_clf, coref_clf, entity_clf, rel_clf, mention_sample_masks,
            mention_pair_sample_masks, entity_sample_masks, rel_sample_masks)

        return dict(mention_clf=mention_clf,
                    coref_clf=coref_clf,
                    entity_clf=entity_clf,
                    rel_clf=rel_clf,
                    clusters=clusters,
                    clusters_sample_masks=clusters_sample_masks,
                    rel_entity_pairs=rel_entity_pairs)