Example #1
    def __init__(self, conf: Dict, word_batch: WordBatch,
                 char_batch: CharacterBatch, n_class: int):
        super(Model, self).__init__(conf, word_batch, char_batch)
        self.dropout = torch.nn.Dropout(conf['dropout'])

        c = conf['classifier']
        classify_layer_name = c['name'].lower()
        if classify_layer_name == 'softmax':
            self.classify_layer = SoftmaxLoss(self.output_dim, n_class)
        elif classify_layer_name == 'cnn_softmax':
            raise NotImplementedError('cnn_softmax is not ready.')
        elif classify_layer_name == 'sampled_softmax':
            sparse = conf['optimizer']['type'].lower() in ('sgd', 'adam',
                                                           'dense_sparse_adam')
            self.classify_layer = SampledSoftmaxLoss(n_class,
                                                     self.output_dim,
                                                     c['n_samples'],
                                                     sparse=sparse)
        elif classify_layer_name == 'window_sampled_softmax':
            sparse = conf['optimizer']['type'].lower() in ('sgd', 'adam',
                                                           'dense_sparse_adam')
            self.classify_layer = WindowSampledSoftmaxLoss(n_class,
                                                           self.output_dim,
                                                           c['n_samples'],
                                                           sparse=sparse)
        else:
            raise ValueError(
                'Unknown classify_layer: {}'.format(classify_layer_name))
Example #2
    def test_sampled_softmax_can_run(self):
        softmax = SampledSoftmaxLoss(num_words=1000, embedding_dim=12, num_samples=50)

        # sequence_length, embedding_dim
        embedding = torch.rand(100, 12)
        targets = torch.randint(0, 1000, (100,)).long()

        _ = softmax(embedding, targets)
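
The test above only checks that a forward pass runs. As a minimal standalone sketch (assuming AllenNLP's allennlp.modules.sampled_softmax_loss.SampledSoftmaxLoss and random stand-in data; this is not code from any of the repositories quoted here), the same constructor and call pattern can be driven by an ordinary optimizer step. The returned loss is a sum over targets, which the model examples below divide by the number of real targets before backpropagating.

import torch
from allennlp.modules.sampled_softmax_loss import SampledSoftmaxLoss

# Hypothetical sizes; num_samples << num_words is the point of sampling.
loss_fn = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=50)
optimizer = torch.optim.SGD(loss_fn.parameters(), lr=0.1)

# Stand-ins for contextualized embeddings and gold next-token ids.
embeddings = torch.rand(100, 12)                  # (num_targets, embedding_dim)
targets = torch.randint(0, 10000, (100,)).long()  # gold word ids

loss_fn.train()                      # sampling is only active in train mode
optimizer.zero_grad()
loss = loss_fn(embeddings, targets)  # summed negative log-likelihood
(loss / targets.size(0)).backward()  # average over targets, as the models do
optimizer.step()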
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 seq_embedder: TextFieldEmbedder,
                 abstract_text_field_embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder,
                 calculate_recall: bool = False,
                 use_abstracts: bool = True,
                 use_node_vectors: bool = True,
                 num_samples: int = None,
                 dropout: float = None) -> None:
        super().__init__(vocab)

        self._abstract_text_field_embedder = abstract_text_field_embedder

        self._use_abstracts = use_abstracts

        self._use_node_vectors = use_node_vectors

        self._seq_embedder = seq_embedder

        self._calculate_recall = calculate_recall
        # lstm encoder uses PytorchSeq2SeqWrapper for pytorch lstm
        self._contextualizer = contextualizer

        self._forward_dim = contextualizer.get_output_dim()

        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=False)
        else:
            self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                              embedding_dim=self._forward_dim)

        self._n_list = range(1, 50)
        self._recall_at_n = {}
        for n in self._n_list:
            self._recall_at_n[n] = RecallAtN(n)
        self._perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x
Example #4
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        contextualizer: Seq2SeqEncoder,
        dropout: float = None,
        num_samples: int = None,
        sparse_embeddings: bool = False,
        bidirectional: bool = False,
        initializer: InitializerApplicator = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = text_field_embedder

        if contextualizer.is_bidirectional() is not bidirectional:
            raise ConfigurationError(
                "Bidirectionality of contextualizer must match bidirectionality of "
                "language model. "
                f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
                f"language model bidirectional: {bidirectional}")

        self._contextualizer = contextualizer
        self._bidirectional = bidirectional

        # The dimension for making predictions just in the forward
        # (or backward) direction.
        if self._bidirectional:
            self._forward_dim = contextualizer.get_output_dim() // 2
        else:
            self._forward_dim = contextualizer.get_output_dim()

        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size("transactions"),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
        else:
            self._softmax_loss = SoftmaxLoss(
                num_words=vocab.get_vocab_size("transactions"),
                embedding_dim=self._forward_dim,
            )

        # This buffer is now unused and exists only for backwards compatibility reasons.
        self.register_buffer("_last_average_loss", torch.zeros(1))

        self._perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        if initializer is not None:
            initializer(self)
Example #5
    def test_sampled_equals_unsampled_during_eval(self):
        sampled_softmax = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=40)
        unsampled_softmax = _SoftmaxLoss(num_words=10000, embedding_dim=12)

        sampled_softmax.eval()
        unsampled_softmax.eval()

        # set weights equal, use transpose because opposite shapes
        sampled_softmax.softmax_w.data = unsampled_softmax.softmax_w.t()
        sampled_softmax.softmax_b.data = unsampled_softmax.softmax_b

        # sequence_length, embedding_dim
        embedding = torch.rand(100, 12)
        targets = torch.randint(0, 1000, (100,)).long()

        full_loss = unsampled_softmax(embedding, targets).item()
        sampled_loss = sampled_softmax(embedding, targets).item()

        # Should be really close
        np.testing.assert_almost_equal(sampled_loss, full_loss)
Example #6
    def test_sampled_equals_unsampled_when_biased_against_non_sampled_positions(
            self):
        sampled_softmax = SampledSoftmaxLoss(num_words=10000,
                                             embedding_dim=12,
                                             num_samples=10)
        unsampled_softmax = SoftmaxLoss(num_words=10000, embedding_dim=12)

        # fake out choice function
        FAKE_SAMPLES = [100, 200, 300, 400, 500, 600, 700, 800, 900, 9999]

        def fake_choice(num_words: int,
                        num_samples: int) -> Tuple[np.ndarray, int]:
            assert (num_words, num_samples) == (10000, 10)
            return np.array(FAKE_SAMPLES), 12

        sampled_softmax.choice_func = fake_choice

        # bias out the unsampled terms:
        for i in range(10000):
            if i not in FAKE_SAMPLES:
                unsampled_softmax.softmax_b[i] = -10000

        # set weights equal, use transpose because opposite shapes
        sampled_softmax.softmax_w.data = unsampled_softmax.softmax_w.t()
        sampled_softmax.softmax_b.data = unsampled_softmax.softmax_b

        sampled_softmax.train()
        unsampled_softmax.train()

        # sequence_length, embedding_dim
        embedding = torch.rand(100, 12)
        targets = torch.randint(0, 1000, (100, )).long()

        full_loss = unsampled_softmax(embedding, targets).item()
        sampled_loss = sampled_softmax(embedding, targets).item()

        # Should be close

        pct_error = (sampled_loss - full_loss) / full_loss
        assert abs(pct_error) < 0.001
Example #7
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder,
                 dropout: float = None,
                 loss_scale: Union[float, str] = 1.0,
                 num_samples: int = None,
                 sparse_embeddings: bool = False,
                 bidirectional: bool = False,
                 initializer: InitializerApplicator = None) -> None:
        super().__init__(vocab)
        self._text_field_embedder = text_field_embedder

        if contextualizer.is_bidirectional() is not bidirectional:
            raise ConfigurationError(
                "Bidirectionality of contextualizer must match bidirectionality of "
                "language model. "
                f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
                f"language model bidirectional: {bidirectional}")

        self._contextualizer = contextualizer
        self._bidirectional = bidirectional

        # The dimension for making predictions just in the forward
        # (or backward) direction.
        if self._bidirectional:
            self._forward_dim = contextualizer.get_output_dim() // 2
        else:
            self._forward_dim = contextualizer.get_output_dim()

        # TODO(joelgrus): more sampled softmax configuration options, as needed.
        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings)
        else:
            self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                              embedding_dim=self._forward_dim)

        # TODO(brendanr): Output perplexity here. e^loss
        self.register_buffer('_last_average_loss', torch.zeros(1))

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        self._loss_scale = loss_scale
        if initializer is not None:
            initializer(self)
Example #8
    def test_sampled_almost_equals_unsampled_when_num_samples_is_almost_all(self):
        sampled_softmax = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=9999)
        unsampled_softmax = _SoftmaxLoss(num_words=10000, embedding_dim=12)

        # sequence_length, embedding_dim
        embedding = torch.rand(100, 12)
        targets = torch.randint(0, 1000, (100,)).long()

        full_loss = unsampled_softmax(embedding, targets).item()
        sampled_loss = sampled_softmax(embedding, targets).item()

        # Should be really close
        pct_error = (sampled_loss - full_loss) / full_loss
        assert abs(pct_error) < 0.02
Example #9
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder,
                 forward_segmental_contextualizer: Seq2SeqEncoder,
                 backward_segmental_contextualizer: Seq2SeqEncoder,
                 label_feature_dim: int,
                 softmax_projection_dim: int,
                 label_namespace: str = "labels",
                 dropout: float = None,
                 num_samples: int = None,
                 sparse_embeddings: bool = False,
                 bidirectional: bool = True,
                 initializer: InitializerApplicator = None) -> None:
        super().__init__(vocab=vocab,
                         text_field_embedder=text_field_embedder,
                         contextualizer=contextualizer,
                         dropout=dropout,
                         num_samples=num_samples,
                         sparse_embeddings=sparse_embeddings,
                         bidirectional=bidirectional,
                         initializer=initializer)
        self._forward_segmental_contextualizer = forward_segmental_contextualizer
        self._backward_segmental_contextualizer = backward_segmental_contextualizer

        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=softmax_projection_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings)
        else:
            self._softmax_loss = _SoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=softmax_projection_dim)

        self.num_classes = self.vocab.get_vocab_size(label_namespace)
        self.label_feature_embedding = Embedding(self.num_classes,
                                                 label_feature_dim)

        base_dim = contextualizer.get_output_dim() // 2
        seg_dim = base_dim + label_feature_dim
        self._forward_dim = softmax_projection_dim

        self.pre_segmental_layer = TimeDistributed(
            Linear(seg_dim, softmax_projection_dim))
        self.projection_layer = TimeDistributed(
            Linear(base_dim * 2, softmax_projection_dim))
Example #10
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        contextualizer: Seq2SeqEncoder,
        hyperbolic_embedder: TextFieldEmbedder,
        hyperbolic_encoder: Seq2VecEncoder,
        hyperbolic_weight: float,
        is_baseline: bool = False,
        dropout: float = None,
        num_samples: int = None,
        sparse_embeddings: bool = False,
        bidirectional: bool = False,
        initializer: InitializerApplicator = None,
    ) -> None:
        super().__init__(
            vocab,
            text_field_embedder,
            contextualizer,
            dropout,
            num_samples,
            sparse_embeddings,
            bidirectional,
            initializer
        )
        # reinitialize self._softmax_loss to change default namespace 'token'
        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(namespace='euclidean'),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
        else:
            self._softmax_loss = SoftmaxLoss(
                num_words=vocab.get_vocab_size(namespace='euclidean'), 
                embedding_dim=self._forward_dim
            )

        # initialize hyperbolic components
        self._hyperbolic_embedder = hyperbolic_embedder
        self._hyperbolic_encoder = hyperbolic_encoder
        self._hyperbolic_encoding_loss = HyperbolicL1()
        self._hyperbolic_weight = hyperbolic_weight

        # vanilla language model
        self.is_baseline = is_baseline
Example #11
    def test_sampled_softmax_has_greater_loss_in_train_mode(self):
        sampled_softmax = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=10)

        # sequence_length, embedding_dim
        embedding = torch.rand(100, 12)
        targets = torch.randint(0, 1000, (100,)).long()

        sampled_softmax.train()
        train_loss = sampled_softmax(embedding, targets).item()

        sampled_softmax.eval()
        eval_loss = sampled_softmax(embedding, targets).item()

        assert eval_loss > train_loss
Example #12
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder,
                 layer_norm: Optional[MaskedLayerNorm] = None,
                 dropout: float = None,
                 loss_scale: Union[float, str] = 1.0,
                 remove_bos_eos: bool = True,
                 num_samples: int = None,
                 sparse_embeddings: bool = False) -> None:
        super().__init__(vocab)
        self._text_field_embedder = text_field_embedder
        self._layer_norm = layer_norm or (lambda x: x)

        if not contextualizer.is_bidirectional():
            raise ConfigurationError("contextualizer must be bidirectional")

        self._contextualizer = contextualizer
        # The dimension for making predictions just in the forward
        # (or backward) direction.
        self._forward_dim = contextualizer.get_output_dim() // 2

        # TODO(joelgrus): more sampled softmax configuration options, as needed.
        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings)
        else:
            self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                              embedding_dim=self._forward_dim)

        self.register_buffer('_last_average_loss', torch.zeros(1))

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        self._loss_scale = loss_scale
        self._remove_bos_eos = remove_bos_eos
Example #13
    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder,
                 dropout: float = None,
                 tie_embeddings: bool = True,
                 num_samples: int = None,
                 use_variational_dropout: bool = False):
        super().__init__(vocab)

        self._embedder = embedder
        self._contextualizer = contextualizer
        self._context_dim = contextualizer.get_output_dim()

        if use_variational_dropout:
            self._dropout = InputVariationalDropout(
                dropout) if dropout else lambda x: x
        else:
            self._dropout = Dropout(dropout) if dropout else lambda x: x

        vocab_size = self.vocab.get_vocab_size()
        padding_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN)
        if num_samples:
            self._softmax_loss = SampledSoftmaxLoss(vocab_size,
                                                    self._context_dim,
                                                    num_samples)
        else:
            self._softmax_loss = SoftmaxLoss(vocab_size, self._context_dim,
                                             padding_index)

        self._tie_embeddings = tie_embeddings
        if self._tie_embeddings:
            embedder_children = dict(self._embedder.named_children())
            word_embedder = embedder_children["token_embedder_tokens"]
            assert (self._softmax_loss.softmax_w.size() ==
                    word_embedder.weight.size())
            self._softmax_loss.softmax_w = word_embedder.weight
Example #14
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        contextualizer: Seq2SeqEncoder,
        aux_contextualizer: Seq2SeqEncoder,
        dropout: float = None,
        num_samples: int = None,
        sparse_embeddings: bool = False,
        bidirectional: bool = False,
        initializer: InitializerApplicator = None,
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder

        if contextualizer.is_bidirectional() is not bidirectional:
            raise ConfigurationError(
                "Bidirectionality of contextualizer must match bidirectionality of "
                "language model. "
                f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
                f"language model bidirectional: {bidirectional}")

        self._contextualizer_lang1 = aux_contextualizer
        self._contextualizer_lang2 = copy.deepcopy(aux_contextualizer)
        self._contextualizer = contextualizer

        self._bidirectional = bidirectional
        self._bidirectional_aux = aux_contextualizer.is_bidirectional()

        # The dimension for making predictions just in the forward
        # (or backward) direction.
        # main contextualizer forward dim
        if self._bidirectional:
            self._forward_dim = contextualizer.get_output_dim() // 2
        else:
            self._forward_dim = contextualizer.get_output_dim()

        # aux contextualizer forward dim
        if self._bidirectional_aux:
            self._forward_dim_aux = aux_contextualizer.get_output_dim() // 2
        else:
            self._forward_dim_aux = aux_contextualizer.get_output_dim()

        # TODO(joelgrus): more sampled softmax configuration options, as needed.
        if num_samples is not None:
            self._lang1_softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
            self._lang2_softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
            self._cm_softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=sparse_embeddings,
            )
        else:
            self._lang1_softmax_loss = _SoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux)
            self._lang2_softmax_loss = _SoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim_aux)
            self._cm_softmax_loss = _SoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim)

        # This buffer is now unused and exists only for backwards compatibility reasons.
        self.register_buffer("_last_average_loss", torch.zeros(1))

        self._lang1_perplexity = Perplexity()
        self._lang2_perplexity = Perplexity()
        self._cm_perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        if initializer is not None:
            initializer(self)
Example #15
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 spellchecker_namespace: str = 'target_tokens',
                 punct_namespace: str = 'punct_labels',
                 feedforward: Optional[FeedForward] = None,
                 punct_hidden: int = 256,
                 embedding_dropout: Optional[float] = None,
                 encoded_dropout: Optional[float] = None,
                 punct_dropout: Optional[float] = None,
                 punct_weight: Optional[Dict[str, float]] = None,
                 num_samples: Optional[int] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self.label_namespace = punct_namespace
        self.text_field_embedder = text_field_embedder
        self.token_vocab_size = self.vocab.get_vocab_size(
            spellchecker_namespace)
        self.punct_vocab_size = self.vocab.get_vocab_size(punct_namespace)
        self.encoder = encoder
        self.embedding_dropout = Dropout(
            embedding_dropout) if embedding_dropout is not None else None
        self.encoded_dropout = Dropout(
            encoded_dropout) if encoded_dropout is not None else None
        self.feedforward = feedforward

        if feedforward is not None:
            self.output_dim = feedforward.get_output_dim()
        else:
            self.output_dim = self.encoder.get_output_dim()

        if punct_dropout is not None:
            self.punct_projection = Sequential(
                Linear(self.output_dim, punct_hidden), Dropout(punct_dropout),
                Linear(punct_hidden, self.punct_vocab_size))
        else:
            self.punct_projection = Sequential(
                Linear(self.output_dim, punct_hidden),
                Linear(punct_hidden, self.punct_vocab_size))

        self.losses = {
            'spellchecker':
            SoftmaxLoss(num_words=self.token_vocab_size,
                        embedding_dim=self.output_dim + 1)
            if num_samples is None else
            SampledSoftmaxLoss(num_words=self.token_vocab_size,
                               embedding_dim=self.output_dim + 1,
                               num_samples=num_samples),
            'punct':
            CrossEntropyLoss(weight=self.__get_weight_tensor(punct_weight),
                             reduction='sum',
                             ignore_index=-1)
        }
        self.add_module('spellchecker_loss', self.losses['spellchecker'])
        self.add_module('punct_loss', self.losses['punct'])

        self.metrics = {'punct_accuracy': CategoricalAccuracy()}
        self.metrics.update({
            f'f1_score_{name}': F1Measure(
                self.vocab.get_token_index(name, namespace=punct_namespace))
            for name in self.vocab.get_token_to_index_vocabulary(
                namespace=punct_namespace)
        })

        initializer(self)
Example #16
class AclSequenceModel(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 seq_embedder: TextFieldEmbedder,
                 abstract_text_field_embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder,
                 calculate_recall: bool = False,
                 use_abstracts: bool = True,
                 use_node_vectors: bool = True,
                 num_samples: int = None,
                 dropout: float = None) -> None:
        super().__init__(vocab)

        self._abstract_text_field_embedder = abstract_text_field_embedder

        self._use_abstracts = use_abstracts

        self._use_node_vectors = use_node_vectors

        self._seq_embedder = seq_embedder

        self._calculate_recall = calculate_recall
        # lstm encoder uses PytorchSeq2SeqWrapper for pytorch lstm
        self._contextualizer = contextualizer

        self._forward_dim = contextualizer.get_output_dim()

        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(
                num_words=vocab.get_vocab_size(),
                embedding_dim=self._forward_dim,
                num_samples=num_samples,
                sparse=False)
        else:
            self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                              embedding_dim=self._forward_dim)

        self._n_list = range(1, 50)
        self._recall_at_n = {}
        for n in self._n_list:
            self._recall_at_n[n] = RecallAtN(n)
        self._perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

    def _compute_loss(self, lm_embeddings: torch.Tensor,
                      targets: torch.Tensor) -> torch.Tensor:

        # Because the targets are offset by 1, we re-mask to
        # remove the final 0 in the targets
        mask = targets > 0
        non_masked_targets = targets.masked_select(mask) - 1
        non_masked_embeddings = lm_embeddings.masked_select(
            mask.unsqueeze(-1)).view(-1, self._forward_dim)

        return self._softmax_loss(non_masked_embeddings, non_masked_targets)

    def _compute_probs(self, lm_embeddings: torch.Tensor,
                       targets: torch.Tensor) -> torch.Tensor:

        # Because the targets are offset by 1, we re-mask to
        # remove the final 0 in the targets
        mask = targets > 0
        non_masked_targets = targets.masked_select(mask) - 1
        non_masked_embeddings = lm_embeddings.masked_select(
            mask.unsqueeze(-1)).view(-1, self._forward_dim)

        return self._softmax_loss.probs(non_masked_embeddings)

    def num_layers(self) -> int:
        return self._contextualizer.num_layers + 1

    def forward(
            self, abstracts: Dict[str, torch.LongTensor],
            paper_ids: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
        """
        Computes the loss from the batch.
        """

        if self._use_abstracts and self._use_node_vectors:
            embeddings = torch.cat([
                self._abstract_text_field_embedder(abstracts)[:, :, 0, :],
                self._seq_embedder(paper_ids)
            ],
                                   dim=-1)
            mask = get_text_field_mask(abstracts, num_wrapping_dims=1)
            mask = mask.sum(dim=-1) > 0
        elif self._use_abstracts:
            embeddings = self._abstract_text_field_embedder(abstracts)[:, :,
                                                                       0, :]
            mask = get_text_field_mask(abstracts, num_wrapping_dims=1)
            mask = mask.sum(dim=-1) > 0
        elif self._use_node_vectors:
            embeddings = self._seq_embedder(paper_ids)
            mask = get_text_field_mask(paper_ids)
        else:
            # When use_node_vectors is false, the embedder should be configured
            # to initialize random embeddings. The redundant else condition
            # makes this difference in behavior a little more explicit, even
            # though the content of the block is identical.
            embeddings = self._seq_embedder(paper_ids)
            mask = get_text_field_mask(paper_ids)

        contextual_embeddings: Union[
            torch.Tensor,
            List[torch.Tensor]] = self._contextualizer(embeddings, mask.long())

        contextual_embeddings_with_dropout = self._dropout(
            contextual_embeddings)

        return_dict = {}

        assert isinstance(contextual_embeddings_with_dropout, torch.Tensor)

        # targets is like paper ids, but offset forward by 1 in the second
        # dimension.
        targets = torch.zeros_like(paper_ids['tokens'])
        targets[:, 0:targets.size()[1] - 1] = paper_ids['tokens'][:, 1:]

        loss = self._compute_loss(contextual_embeddings_with_dropout, targets)

        num_targets = torch.sum((targets > 0).long())
        if num_targets > 0:
            average_loss = loss / num_targets.float()
        else:
            average_loss = torch.tensor(0.0).to(targets.device)

        perplexity = self._perplexity(average_loss)

        if self._calculate_recall:
            top_k = self.get_recall_at_n(contextual_embeddings, targets)
            return_dict.update({'top_k': top_k})

        if num_targets > 0:
            return_dict.update({
                'loss': average_loss,
                'batch_weight': num_targets.float()
            })
        else:
            return_dict.update({'loss': average_loss})

        return_dict.update({
            'lm_embeddings': contextual_embeddings,
            'lm_targets': targets,
            'noncontextual_embeddings': embeddings,
        })

        return return_dict

    def get_recall_at_n(self, embeddings, targets):
        top_n = []
        # iterate over batches:
        for embeddings, targets in zip(embeddings.detach(), targets.detach()):
            # (sequence_length, #targets)
            probs = self._compute_probs(embeddings, targets)
            top_probs, top_indices = probs.topk(k=max(self._n_list), dim=-1)
            top_ids = [[
                self.vocab.get_token_from_index(int(i)) for i in indices
            ] for indices in top_indices]
            top_n.append(top_ids)

            mask = targets > 0
            non_masked_targets = targets.masked_select(mask) - 1
            for n in self._n_list:
                self._recall_at_n[n](non_masked_targets, top_indices)
        return top_n

    def get_metrics(self, reset: bool = False):
        metrics = {"perplexity": self._perplexity.get_metric(reset=reset)}
        if self._calculate_recall:
            for n in self._n_list:
                recall = self._recall_at_n[n].get_metric(reset=reset)
                metrics.update({"recall_at_{}".format(n): recall})
        return metrics
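
The masking in _compute_loss and _compute_probs above is the part that is easiest to get wrong, so here is a small self-contained sketch of it on toy tensors (shapes and values are invented, not taken from the example). Index 0 in targets marks positions without a gold id, and the "- 1" shifts the surviving ids down by one, matching the offset noted in the method's comment.

import torch

forward_dim = 4
lm_embeddings = torch.rand(2, 3, forward_dim)  # (batch, sequence_length, forward_dim)
targets = torch.tensor([[5, 9, 0],             # 0 marks padded / final positions
                        [2, 0, 0]])

mask = targets > 0
non_masked_targets = targets.masked_select(mask) - 1   # tensor([4, 8, 1])
non_masked_embeddings = lm_embeddings.masked_select(
    mask.unsqueeze(-1)).view(-1, forward_dim)          # shape (3, forward_dim)
# These two tensors are what self._softmax_loss receives in _compute_loss.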