def test_span_scorer_works_for_completely_masked_rows(self):
    # Really simple scorer - sum up the embedding_dim.
    scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
    pruner = SpanPruner(scorer=scorer)  # type: ignore
    spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
    spans[0, :2, :] = 1
    spans[1, 2:, :] = 1
    spans[2, 2:, :] = 1

    mask = torch.ones([3, 4])
    mask[1, 0] = 0
    mask[1, 3] = 0
    mask[2, :] = 0  # Fully masked last batch element.
    pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2)

    # We can't check the last row here, because it's completely masked.
    # Instead, we check that the scores for those elements are -inf.
    numpy.testing.assert_array_equal(pruned_indices[:2].data.numpy(),
                                     numpy.array([[0, 1], [1, 2]]))
    numpy.testing.assert_array_equal(pruned_mask.data.numpy(),
                                     numpy.array([[1, 1], [1, 1], [0, 0]]))
    # The embeddings should be the result of index_selecting the pruned_indices.
    correct_embeddings = batched_index_select(spans, pruned_indices)
    numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                     pruned_embeddings.data.numpy())
    # The scores should be the sum of the correct embedding elements, with
    # masked elements equal to -inf.
    correct_scores = correct_embeddings.sum(-1).unsqueeze(-1).data.numpy()
    correct_scores[2, :] = float("-inf")
    numpy.testing.assert_array_equal(correct_scores, pruned_scores.data.numpy())
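# An illustrative sketch (not part of the test above) of what batched_index_select
# is assumed to do for this 3D case: select rows along dim 1 independently for each
# batch element. A plain-PyTorch equivalent using torch.gather:
import torch

spans = torch.randn(3, 4, 5)
indices = torch.tensor([[0, 1], [1, 2], [2, 3]])
# Expand the index tensor over the embedding dim so gather can select whole rows.
expanded = indices.unsqueeze(-1).expand(-1, -1, spans.size(-1))
selected = spans.gather(1, expanded)
assert selected.shape == (3, 2, 5)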
def test_span_pruner_selects_top_scored_spans_and_respects_masking(self):
    # Really simple scorer - sum up the embedding_dim.
    scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1)
    pruner = SpanPruner(scorer=scorer)
    spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
    spans[0, :2, :] = 1
    spans[1, 2:, :] = 1
    spans[2, 2:, :] = 1

    mask = torch.ones([3, 4])
    mask[1, 0] = 0
    mask[1, 3] = 0
    pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2)

    # The second element in the batch would have indices 2, 3, but
    # indices 3 and 0 are masked, so instead it has 1, 2.
    numpy.testing.assert_array_equal(pruned_indices.data.numpy(),
                                     numpy.array([[0, 1], [1, 2], [2, 3]]))
    numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.ones([3, 2]))
    # The embeddings should be the result of index_selecting the pruned_indices.
    correct_embeddings = batched_index_select(spans, pruned_indices)
    numpy.testing.assert_array_equal(correct_embeddings.data.numpy(),
                                     pruned_embeddings.data.numpy())
    # The scores should be the sum of the correct embedding elements.
    numpy.testing.assert_array_equal(correct_embeddings.sum(-1).unsqueeze(-1).data.numpy(),
                                     pruned_scores.data.numpy())
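# A sketch of the masking behaviour these tests rely on. The pruner presumably
# pushes masked positions to a very low score before taking topk, so a masked
# span can never outrank an unmasked one. A minimal plain-PyTorch version:
import torch

scores = torch.randn(3, 4)
mask = torch.ones(3, 4)
mask[1, 0] = 0
mask[1, 3] = 0
masked_scores = scores.masked_fill(mask == 0, float("-inf"))
top_scores, top_indices = masked_scores.topk(2, dim=-1)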
def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             context_layer: Seq2SeqEncoder,
             mention_feedforward: FeedForward,
             antecedent_feedforward: FeedForward,
             feature_size: int,
             max_span_width: int,
             spans_per_word: float,
             max_antecedents: int,
             lexical_dropout: float = 0.2,
             context_layer_back: Optional[Seq2SeqEncoder] = None,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None) -> None:
    super(CoreferenceResolver, self).__init__(vocab, regularizer)

    self._text_field_embedder = text_field_embedder
    self._context_layer = context_layer
    self._context_layer_back = context_layer_back
    self._antecedent_feedforward = TimeDistributed(antecedent_feedforward)
    feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(mention_feedforward),
            TimeDistributed(torch.nn.Linear(mention_feedforward.get_output_dim(), 1)))
    self._mention_pruner = SpanPruner(feedforward_scorer)
    self._antecedent_scorer = TimeDistributed(
            torch.nn.Linear(antecedent_feedforward.get_output_dim(), 1))

    # TODO: check the output dim when two context layers are passed through.
    self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
    self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=text_field_embedder.get_output_dim())

    # 10 possible distance buckets.
    self._num_distance_buckets = 10
    self._distance_embedding = Embedding(self._num_distance_buckets, feature_size)
    self._speaker_embedding = Embedding(2, feature_size)
    self.genres = {g: i for i, g in enumerate(['bc', 'bn', 'mz', 'nw', 'pt', 'tc', 'wb'])}
    self._genre_embedding = Embedding(len(self.genres), feature_size)

    self._max_span_width = max_span_width
    self._spans_per_word = spans_per_word
    self._max_antecedents = max_antecedents
    self._mention_recall = MentionRecall()
    self._conll_coref_scores = ConllCorefScores()

    if lexical_dropout > 0:
        self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
    else:
        self._lexical_dropout = lambda x: x
    self._feature_dropout = torch.nn.Dropout(0.2)
    initializer(self)
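# A hypothetical usage sketch for the genre features above: map a document's
# genre code (the first segment of an OntoNotes document id, e.g. "nw" in
# "nw/wsj/...") to its embedding. `model` stands in for a constructed
# CoreferenceResolver; the batching here is illustrative only.
import torch

genre_index = model.genres["nw"]  # -> 3
genre_embedding = model._genre_embedding(torch.tensor([genre_index]))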
def test_span_scorer_raises_with_incorrect_scorer_spec(self):
    # Mis-configured scorer - doesn't produce a tensor with 1 as its final dimension.
    scorer = lambda tensor: tensor.sum(-1)
    pruner = SpanPruner(scorer=scorer)  # type: ignore
    spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
    mask = torch.ones([3, 4])

    with pytest.raises(ValueError):
        _ = pruner(spans, mask, 2)
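# For contrast, a minimal correctly-specified scorer: it must keep a trailing
# dimension of size 1, so the pruner receives a (batch_size, num_spans, 1)
# score tensor rather than (batch_size, num_spans).
import torch

good_scorer = lambda tensor: tensor.sum(-1, keepdim=True)
assert good_scorer(torch.randn(3, 4, 5)).shape == (3, 4, 1)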
def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             context_layer: Seq2SeqEncoder,
             mention_feedforward: FeedForward,
             antecedent_feedforward: FeedForward,
             feature_size: int,
             max_span_width: int,
             spans_per_word: float,
             max_antecedents: int,
             lexical_dropout: float = 0.2,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None,
             coarse_to_fine_pruning: bool = False) -> None:
    super(CoreferenceResolver, self).__init__(vocab, regularizer)

    self._text_field_embedder = text_field_embedder
    self._context_layer = context_layer
    self._antecedent_feedforward = TimeDistributed(antecedent_feedforward)
    feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(mention_feedforward),
            TimeDistributed(torch.nn.Linear(mention_feedforward.get_output_dim(), 1)))
    self._mention_pruner = SpanPruner(feedforward_scorer)
    self._antecedent_scorer = TimeDistributed(
            torch.nn.Linear(antecedent_feedforward.get_output_dim(), 1))
    # Whether to do coarse-to-fine pruning of antecedents.
    self._do_coarse_to_fine_prune = coarse_to_fine_pruning

    self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
    self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=text_field_embedder.get_output_dim())

    # 10 possible distance buckets.
    self._num_distance_buckets = 10
    self._distance_embedding = Embedding(self._num_distance_buckets, feature_size)

    self._max_span_width = max_span_width
    self._spans_per_word = spans_per_word
    self._max_antecedents = max_antecedents
    self._mention_recall = MentionRecall()
    self._conll_coref_scores = ConllCorefScores()

    if lexical_dropout > 0:
        self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
    else:
        self._lexical_dropout = lambda x: x
    initializer(self)
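# A minimal sketch of the coarse scoring step the coarse_to_fine_pruning flag
# presumably enables (after Lee et al., 2018): a cheap bilinear score between
# pruned span embeddings, used to keep only the top candidate antecedents per
# span before running the expensive pairwise feedforward scorer. Shapes and
# the `coarse_weight` module are illustrative assumptions, not this model's code.
import torch

batch_size, num_spans, embedding_dim, max_antecedents = 2, 10, 20, 5
span_embeddings = torch.randn(batch_size, num_spans, embedding_dim)
coarse_weight = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)

# (batch_size, num_spans, num_spans) matrix of cheap pairwise scores.
coarse_scores = torch.bmm(coarse_weight(span_embeddings),
                          span_embeddings.transpose(1, 2))
# Keep the max_antecedents highest-scoring candidates for each span.
top_scores, top_indices = coarse_scores.topk(max_antecedents, dim=-1)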
def __init__(self,
             vocab: Vocabulary,
             source_embedder: TextFieldEmbedder,
             encoder: Seq2SeqEncoder,
             max_decoding_steps: int,
             spans_per_word: float,
             target_namespace: str = "tokens",
             target_embedding_dim: int = None,
             attention_function: SimilarityFunction = None,
             scheduled_sampling_ratio: float = 0.0,
             spans_extractor: SpanExtractor = None,
             spans_scorer_feedforward: FeedForward = None) -> None:
    super(SpanAe, self).__init__(vocab)
    self._source_embedder = source_embedder
    self._encoder = encoder
    self._max_decoding_steps = max_decoding_steps
    self._target_namespace = target_namespace
    self._attention_function = attention_function
    self._scheduled_sampling_ratio = scheduled_sampling_ratio
    # We need the start symbol to provide as the input at the first timestep of decoding,
    # and the end symbol as a way to indicate the end of the decoded sequence.
    self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
    self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
    num_classes = self.vocab.get_vocab_size(self._target_namespace)
    # The decoder output dim needs to match the encoder output dim, since we initialize
    # the hidden state of the decoder with the final hidden state of the encoder. Also,
    # if we're using attention with ``DotProductSimilarity``, this is needed. (Note the
    # extra dimension added here on top of the encoder output dim.)
    self._decoder_output_dim = self._encoder.get_output_dim() + 1
    target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()
    self._target_embedder = Embedding(num_classes, target_embedding_dim)
    if self._attention_function:
        self._decoder_attention = Attention(self._attention_function)
        # The output of attention, a weighted average over encoder outputs, will be
        # concatenated to the input vector of the decoder at each time step.
        self._decoder_input_dim = self._encoder.get_output_dim() + target_embedding_dim
    else:
        self._decoder_input_dim = target_embedding_dim
    self._decoder_cell = LSTMCell(self._decoder_input_dim + 1, self._decoder_output_dim)
    self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    self._span_extractor = spans_extractor
    feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(spans_scorer_feedforward),
            TimeDistributed(torch.nn.Linear(spans_scorer_feedforward.get_output_dim(), 1)))
    self._span_pruner = SpanPruner(feedforward_scorer)
    self._spans_per_word = spans_per_word
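# A sketch (not part of this constructor) of how the pruner might be applied
# later in forward(), assuming spans_per_word scales the number of retained
# spans with the source length, following the pattern in AllenNLP's
# coreference model. `prune_spans` is a hypothetical helper; its arguments
# are assumed to be produced earlier in forward().
import math

def prune_spans(span_pruner, span_embeddings, span_mask, spans_per_word, source_length):
    # Keep a number of spans proportional to the length of the source sequence.
    num_spans_to_keep = int(math.floor(spans_per_word * source_length))
    # Returns (pruned_embeddings, pruned_mask, pruned_indices, pruned_scores).
    return span_pruner(span_embeddings, span_mask, num_spans_to_keep)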