Example #1
    def test_pruner_selects_top_scored_items_and_respects_masking(self):
        # Really simple scorer - sum up the embedding_dim.
        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
        pruner = Pruner(scorer=scorer)

        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        items[0, :2, :] = 1
        items[1, 2:, :] = 1
        items[2, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(
            items, mask, 2)

        # Second element in the batch would have indices 2, 3, but
        # 3 and 0 are masked, so instead it has 1, 2.
        numpy.testing.assert_array_equal(pruned_indices.data.numpy(),
                                         numpy.array([[0, 1], [1, 2], [2, 3]]))
        numpy.testing.assert_array_equal(pruned_mask.data.numpy(),
                                         numpy.ones([3, 2]))

        # embeddings should be the result of index_selecting the pruned_indices.
        correct_embeddings = batched_index_select(items, pruned_indices)
        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                         pruned_embeddings.data.numpy())
        # scores should be the sum of the correct embedding elements.
        numpy.testing.assert_array_equal(
            correct_embeddings.sum(-1).unsqueeze(-1).data.numpy(),
            pruned_scores.data.numpy())
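
The assertions above depend on two behaviours: masked items are never selected, and the selected indices come back in their original order. Below is a minimal, self-contained sketch of that selection logic for a single made-up row (illustrative only, not part of the test):

import torch

row = torch.tensor([[1.0, 1.0], [0.2, 0.1], [0.9, 0.8], [0.3, 0.3]])  # (num_items, embedding_dim)
row_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])                         # last item masked

scores = row.sum(-1)                                # the simple sum scorer, per item
scores = scores.masked_fill(row_mask == 0, -1e20)   # masked items can never win
top_scores, top_indices = scores.topk(2)            # keep the two highest-scoring items
top_indices, _ = top_indices.sort()                 # report indices in original order
print(top_indices)                                  # tensor([0, 2])
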
Example #2
    def test_scorer_raises_with_incorrect_scorer_spec(self):
        # Mis-configured scorer - doesn't produce a tensor with 1 as its final dimension.
        scorer = lambda tensor: tensor.sum(-1)
        pruner = Pruner(scorer=scorer)  # type: ignore
        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        mask = torch.ones([3, 4])

        with pytest.raises(ValueError):
            _ = pruner(items, mask, 2)
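
The failure here is purely a shape-contract issue: the test implies the pruner expects the scorer to return one score per item, with a trailing dimension of 1. A small sketch of the two shapes (illustrative tensors):

import torch

items = torch.randn(3, 4, 5)
bad_scores = items.sum(-1)                 # shape (3, 4): missing the trailing score dimension
good_scores = items.sum(-1).unsqueeze(-1)  # shape (3, 4, 1): one score per item
print(bad_scores.shape, good_scores.shape)
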
Example #3
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        context_layer: Seq2SeqEncoder,
        mention_feedforward: FeedForward,
        antecedent_feedforward: FeedForward,
        feature_size: int,
        max_span_width: int,
        spans_per_word: float,
        max_antecedents: int,
        lexical_dropout: float = 0.2,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer
        self._antecedent_feedforward = TimeDistributed(antecedent_feedforward)
        feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(mention_feedforward),
            TimeDistributed(torch.nn.Linear(mention_feedforward.get_output_dim(), 1)),
        )
        self._mention_pruner = Pruner(feedforward_scorer)
        self._antecedent_scorer = TimeDistributed(
            torch.nn.Linear(antecedent_feedforward.get_output_dim(), 1)
        )

        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False,
        )
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=text_field_embedder.get_output_dim()
        )

        # 10 possible distance buckets.
        self._num_distance_buckets = 10
        self._distance_embedding = Embedding(self._num_distance_buckets, feature_size)

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._max_antecedents = max_antecedents

        self._mention_recall = MentionRecall()
        self._conll_coref_scores = ConllCorefScores()
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x
        initializer(self)
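
The constructor only stores spans_per_word; how many spans the pruner actually keeps is decided later in the forward pass (not shown here). A hedged sketch of the usual computation, with made-up numbers, assuming the span budget is proportional to document length:

import math

spans_per_word = 0.4      # illustrative value
document_length = 120     # number of words in one document (illustrative)

num_spans_to_keep = int(math.floor(spans_per_word * document_length))
print(num_spans_to_keep)  # 48

# The stored pruner would then be invoked roughly like:
# (top_embeddings, top_mask, top_indices, top_scores) = self._mention_pruner(
#     span_embeddings, span_mask, num_spans_to_keep)
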
Example #4
    def __init__(self,
                 input_dim: int,
                 extra_input_dim: int = 0,
                 span_hidden_dim: int = 100,
                 span_ffnn: FeedForward = None,
                 classifier: SetClassifier = SetBinaryClassifier(),
                 span_decoding_threshold: float = 0.05,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None):
        super(SpanSelector, self).__init__()

        self._input_dim = input_dim
        self._extra_input_dim = extra_input_dim
        self._span_hidden_dim = span_hidden_dim
        self._span_ffnn = span_ffnn
        self._classifier = classifier
        self._span_decoding_threshold = span_decoding_threshold

        self._span_hidden = SpanRepAssembly(self._input_dim, self._input_dim, self._span_hidden_dim)
        if self._span_ffnn is not None:
            if self._span_ffnn.get_input_dim() != self._span_hidden_dim:
                raise ConfigurationError(
                    "Span hidden dim %s must match span classifier FFNN input dim %s" % (
                        self._span_hidden_dim, self._span_ffnn.get_input_dim()
                    )
                )
            self._span_scorer = TimeDistributed(
                torch.nn.Sequential(
                    ReLU(),
                    self._span_ffnn,
                    Linear(self._span_ffnn.get_output_dim(), 1)))
        else:
            self._span_scorer = TimeDistributed(
                torch.nn.Sequential(
                    ReLU(),
                    Linear(self._span_hidden_dim, 1)))
        self._span_pruner = Pruner(self._span_scorer)

        if self._extra_input_dim > 0:
            self._extra_input_lin = Linear(self._extra_input_dim, self._span_hidden_dim)
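
When span_ffnn is None, the scorer built above reduces to a ReLU followed by a Linear that emits a single score per span. A plain-torch sketch of that default path (the TimeDistributed wrapper is omitted here, since Linear already broadcasts over leading dimensions):

import torch
from torch.nn import Linear, ReLU, Sequential

span_hidden_dim = 100
scorer = Sequential(ReLU(), Linear(span_hidden_dim, 1))

span_reps = torch.randn(2, 7, span_hidden_dim)  # (batch, num_spans, span_hidden_dim)
scores = scorer(span_reps)                      # (batch, num_spans, 1), as the pruner expects
print(scores.shape)
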
Example #5
    def test_pruner_works_for_row_with_no_items_requested(self):
        # Case where `num_items_to_keep` is a tensor rather than an int. Make sure it does the right
        # thing when no items are requested for one of the rows.
        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
        pruner = Pruner(scorer=scorer)

        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        items[0, :3, :] = 1
        items[1, 2:, :] = 1
        items[2, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0

        num_items_to_keep = torch.tensor([3, 2, 0], dtype=torch.long)

        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(
            items, mask, num_items_to_keep)

        # First element just picks the top three entries. The second would pick entries 2 and 3,
        # but 0 and 3 are masked, so it takes 1 and 2 (padding to length 3 by repeating index 2).
        # The third element requests zero items, so its output is entirely masked and simply
        # repeats the largest index among its top-3 scores.
        numpy.testing.assert_array_equal(
            pruned_indices.data.numpy(),
            numpy.array([[0, 1, 2], [1, 2, 2], [3, 3, 3]]))
        numpy.testing.assert_array_equal(
            pruned_mask.data.numpy(),
            numpy.array([[1, 1, 1], [1, 1, 0], [0, 0, 0]]))

        # embeddings should be the result of index_selecting the pruned_indices.
        correct_embeddings = batched_index_select(items, pruned_indices)
        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                         pruned_embeddings.data.numpy())
        # scores should be the sum of the correct embedding elements.
        numpy.testing.assert_array_equal(
            correct_embeddings.sum(-1).unsqueeze(-1).data.numpy(),
            pruned_scores.data.numpy())
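
In practice the per-row tensor passed as num_items_to_keep has to be built by the caller. A hedged sketch (not from the test) of one common way to do that, e.g. keeping at most half of the unmasked items in each row:

import torch

mask = torch.ones(3, 4)
mask[1, 0] = 0
mask[1, 3] = 0

ratio = 0.5
num_items_to_keep = (mask.sum(-1) * ratio).floor().long()
print(num_items_to_keep)  # tensor([2, 1, 2])
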
Example #6
    def test_scorer_works_for_completely_masked_rows(self):
        # Really simple scorer - sum up the embedding_dim.
        scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
        pruner = Pruner(scorer=scorer)  # type: ignore

        items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
        items[0, :2, :] = 1
        items[1, 2:, :] = 1
        items[2, 2:, :] = 1

        mask = torch.ones([3, 4])
        mask[1, 0] = 0
        mask[1, 3] = 0
        mask[2, :] = 0  # fully masked last batch element.

        pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(
            items, mask, 2)

        # We can't check the last row here, because it's completely masked.
        # Instead we'll check that the scores for these elements are very small.
        numpy.testing.assert_array_equal(pruned_indices[:2].data.numpy(),
                                         numpy.array([[0, 1], [1, 2]]))
        numpy.testing.assert_array_equal(pruned_mask.data.numpy(),
                                         numpy.array([[1, 1], [1, 1], [0, 0]]))
        # embeddings should be the result of index_selecting the pruned_indices.
        correct_embeddings = batched_index_select(items, pruned_indices)
        numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                         pruned_embeddings.data.numpy())
        # scores should be the sum of the correct embedding elements, with masked elements very
        # small (but not -inf, because that can cause problems).  We'll test these two cases
        # separately.
        correct_scores = correct_embeddings.sum(-1).unsqueeze(-1).data.numpy()
        numpy.testing.assert_array_equal(correct_scores[:2],
                                         pruned_scores[:2].data.numpy())
        numpy.testing.assert_array_equal(pruned_scores[2] < -1e15, [[1], [1]])
        numpy.testing.assert_array_equal(pruned_scores[2] == float('-inf'),
                                         [[0], [0]])
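
The final two assertions check that scores for the fully masked row are very negative but still finite. A minimal sketch (illustrative values) of producing such scores with masked_fill:

import torch

scores = torch.tensor([[2.0], [1.5]])
row_mask = torch.tensor([[0.0], [0.0]])   # both entries masked

masked_scores = scores.masked_fill(row_mask == 0, -1e20)
print(masked_scores)                      # very small, but finite
print(torch.isinf(masked_scores).any())   # tensor(False)
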
Example #7
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 context_layer: Seq2SeqEncoder,
                 relex_feedforward: FeedForward,
                 antecedent_feedforward: FeedForward,
                 feature_size: int,
                 max_span_width: int,
                 spans_per_word: float,
                 relex_spans_per_word: float,
                 max_antecedents: int,
                 mention_feedforward: FeedForward,
                 coref_mention_feedforward: FeedForward = None,
                 relex_mention_feedforward: FeedForward = None,
                 symmetric_relations: bool = False,
                 lexical_dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 loss_coref_weight: float = 1,
                 loss_relex_weight: float = 1,
                 loss_ner_weight: float = 1,
                 preserve_metadata: List = None,
                 relex_namespace: str = 'relation_labels') -> None:
        # If separate coref mention and relex mention feedforward scorers
        # are not provided, share the one from the NER module.
        if coref_mention_feedforward is None:
            coref_mention_feedforward = mention_feedforward
        if relex_mention_feedforward is None:
            relex_mention_feedforward = mention_feedforward

        super().__init__(vocab, text_field_embedder, context_layer,
                         coref_mention_feedforward, antecedent_feedforward,
                         feature_size, max_span_width, spans_per_word,
                         max_antecedents, lexical_dropout, initializer,
                         regularizer)

        self._symmetric_relations = symmetric_relations
        self._relex_spans_per_word = relex_spans_per_word
        self._loss_coref_weight = loss_coref_weight
        self._loss_relex_weight = loss_relex_weight
        self._loss_ner_weight = loss_ner_weight
        self._preserve_metadata = preserve_metadata or ['id']
        self._relex_namespace = relex_namespace

        relex_labels = list(
            vocab.get_token_to_index_vocabulary(self._relex_namespace))
        self._relex_mention_recall = RelexMentionRecall()
        self._relex_precision_recall_fscore = PrecisionRecallFScore(
            labels=relex_labels)

        relex_mention_scorer = Sequential(
            TimeDistributed(relex_mention_feedforward),
            TimeDistributed(
                Projection(relex_mention_feedforward.get_output_dim())))
        self._relex_mention_pruner = MultiTimeDistributed(
            Pruner(relex_mention_scorer))

        self._ner_scorer = Sequential(
            TimeDistributed(mention_feedforward),
            TimeDistributed(
                Projection(mention_feedforward.get_output_dim(),
                           vocab.get_vocab_size('ner_labels'),
                           with_dummy=True)))

        self._relex_scorer = Sequential(
            TimeDistributed(relex_feedforward),
            TimeDistributed(
                Projection(relex_feedforward.get_output_dim(),
                           vocab.get_vocab_size(self._relex_namespace),
                           with_dummy=True)))
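
The three loss weights stored above are only hyperparameters at construction time. A hedged sketch (not from this class) of how such weights are typically combined into a single multi-task loss in the forward pass:

coref_loss, relex_loss, ner_loss = 0.75, 1.25, 0.5   # illustrative loss values
loss_coref_weight, loss_relex_weight, loss_ner_weight = 1.0, 1.0, 1.0

total_loss = (loss_coref_weight * coref_loss
              + loss_relex_weight * relex_loss
              + loss_ner_weight * ner_loss)
print(total_loss)  # 2.5
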
Example #8
    def __init__(
            self,
            input_dim: int,
            extra_input_dim: int = 0,
            span_hidden_dim: int = 100,
            span_ffnn: FeedForward = None,
            objective: str = "binary",
            gold_span_selection_policy: str = "union",
            pruning_ratio: float = 2.0,
            skip_metrics_during_training: bool = True,
            # metric: SpanMetric = SpanMetric(),
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None):
        super(PruningSpanSelector, self).__init__()

        self._input_dim = input_dim
        self._span_hidden_dim = span_hidden_dim
        self._extra_input_dim = extra_input_dim
        self._span_ffnn = span_ffnn
        self._pruning_ratio = pruning_ratio
        self._objective = objective
        self._gold_span_selection_policy = gold_span_selection_policy
        self._skip_metrics_during_training = skip_metrics_during_training

        if objective not in objective_values:
            raise ConfigurationError(
                "QA objective must be one of the following: " +
                str(objective_values))

        if gold_span_selection_policy not in gold_span_selection_policy_values:
            raise ConfigurationError(
                "QA span selection policy must be one of the following: " +
                str(gold_span_selection_policy_values))

        if objective == "multinomial" and gold_span_selection_policy == "weighted":
            raise ConfigurationError(
                "Cannot use weighted span selection policy with multinomial objective."
            )

        # self._metric = metric

        self._span_hidden = SpanRepAssembly(input_dim, input_dim,
                                            self._span_hidden_dim)

        if self._span_ffnn is not None:
            if self._span_ffnn.get_input_dim() != self._span_hidden_dim:
                raise ConfigurationError(
                    "Span hidden dim %s must match span classifier FFNN input dim %s"
                    % (self._span_hidden_dim, self._span_ffnn.get_input_dim()))
            self._span_scorer = TimeDistributed(
                torch.nn.Sequential(
                    ReLU(), self._span_ffnn,
                    Linear(self._span_ffnn.get_output_dim(), 1)))
        else:
            self._span_scorer = TimeDistributed(
                torch.nn.Sequential(ReLU(), Linear(self._span_hidden_dim, 1)))

        self._span_pruner = Pruner(self._span_scorer)

        if self._extra_input_dim > 0:
            self._extra_input_lin = Linear(self._extra_input_dim,
                                           self._span_hidden_dim)
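
As with the other selectors, pruning_ratio is only stored here; the span budget itself is computed later from the input length. A hedged sketch (the forward pass is not shown in this example) of turning the ratio into a per-instance budget:

import torch

pruning_ratio = 2.0
text_mask = torch.ones(2, 15)                    # (batch, num_tokens), illustrative
num_tokens = text_mask.sum(-1)                   # tokens per instance
num_spans_to_keep = (pruning_ratio * num_tokens).long()
print(num_spans_to_keep)                         # tensor([30, 30])
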
Example #9
    def __init__(self,
                 vocab: Vocabulary,
                 embedding_dim: int,
                 feature_size: int,
                 max_span_width: int,
                 spans_per_word: float,
                 lexical_dropout: float = 0.2,
                 mlp_dropout: float = 0.4,
                 embedder_type=None,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(SCIIE, self).__init__(vocab, regularizer)
        self.class_num = self.vocab.get_vocab_size('labels')
        word_embeddings = get_embeddings(embedder_type, self.vocab,
                                         embedding_dim, True)
        embedding_dim = word_embeddings.get_output_dim()
        self._text_field_embedder = word_embeddings

        context_layer = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(embedding_dim,
                          feature_size,
                          batch_first=True,
                          bidirectional=True))
        self._context_layer = context_layer

        endpoint_span_extractor_input_dim = context_layer.get_output_dim()
        attentive_span_extractor_input_dim = word_embeddings.get_output_dim()

        self._endpoint_span_extractor = EndpointSpanExtractor(
            endpoint_span_extractor_input_dim,
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=attentive_span_extractor_input_dim)

        # self._span_extractor = PoolingSpanExtractor(embedding_dim,
        #                                             num_width_embeddings=max_span_width,
        #                                             span_width_embedding_dim=feature_size,
        #                                             bucket_widths=False)

        entity_feedforward = FeedForward(
            self._endpoint_span_extractor.get_output_dim() +
            self._attentive_span_extractor.get_output_dim(), 2, 150, F.relu,
            mlp_dropout)
        # entity_feedforward = FeedForward(self._span_extractor.get_output_dim(), 2, 150,
        #                                  F.relu, mlp_dropout)

        feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(), 1)))
        self._mention_pruner = Pruner(feedforward_scorer)

        self._entity_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(),
                                self.class_num - 1)))

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x

        self._metric_all = FBetaMeasure()
        self._metric_avg = NERF1Metric()
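
The endpoint span extractor above is sized from context_layer.get_output_dim() rather than feature_size because the LSTM is bidirectional. A plain-torch sketch confirming that a bidirectional LSTM with hidden size feature_size produces outputs of dimension 2 * feature_size:

import torch

embedding_dim, feature_size = 50, 20
lstm = torch.nn.LSTM(embedding_dim, feature_size, batch_first=True, bidirectional=True)

tokens = torch.randn(2, 9, embedding_dim)  # (batch, seq_len, embedding_dim)
outputs, _ = lstm(tokens)
print(outputs.shape)                       # torch.Size([2, 9, 40]), i.e. 2 * feature_size
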