def __init__(self, conf: Dict, word_batch: WordBatch, char_batch: CharacterBatch, n_class: int):
    super(Model, self).__init__(conf, word_batch, char_batch)
    self.dropout = torch.nn.Dropout(conf['dropout'])

    c = conf['classifier']
    classify_layer_name = c['name'].lower()
    if classify_layer_name == 'softmax':
        self.classify_layer = SoftmaxLoss(self.output_dim, n_class)
    elif classify_layer_name == 'cnn_softmax':
        raise NotImplementedError('cnn_softmax is not ready.')
    elif classify_layer_name == 'sampled_softmax':
        sparse = conf['optimizer']['type'].lower() in ('sgd', 'adam', 'dense_sparse_adam')
        self.classify_layer = SampledSoftmaxLoss(n_class, self.output_dim, c['n_samples'], sparse=sparse)
    elif classify_layer_name == 'window_sampled_softmax':
        sparse = conf['optimizer']['type'].lower() in ('sgd', 'adam', 'dense_sparse_adam')
        self.classify_layer = WindowSampledSoftmaxLoss(n_class, self.output_dim, c['n_samples'], sparse=sparse)
    else:
        raise ValueError('Unknown classify_layer: {}'.format(classify_layer_name))

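# Illustrative only: a minimal, hypothetical `conf` dict that would exercise the
# 'sampled_softmax' branch of the constructor above. The real schema is defined by
# the surrounding project (the base Model class likely reads further keys); only
# the keys this snippet actually accesses are shown.
example_conf = {
    'dropout': 0.1,
    'classifier': {
        'name': 'sampled_softmax',   # or 'softmax' / 'window_sampled_softmax'
        'n_samples': 8192,           # number of sampled words per batch
    },
    'optimizer': {
        'type': 'sgd',               # the listed optimizer types switch the classifier to sparse gradients
    },
}
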
def test_sampled_softmax_can_run(self):
    softmax = SampledSoftmaxLoss(num_words=1000, embedding_dim=12, num_samples=50)

    # sequence_length, embedding_dim
    embedding = torch.rand(100, 12)
    targets = torch.randint(0, 1000, (100,)).long()

    _ = softmax(embedding, targets)

def __init__(
    self,
    vocab: Vocabulary,
    text_field_embedder: TextFieldEmbedder,
    contextualizer: Seq2SeqEncoder,
    dropout: float = None,
    num_samples: int = None,
    sparse_embeddings: bool = False,
    bidirectional: bool = False,
    initializer: InitializerApplicator = None,
    **kwargs,
) -> None:
    super().__init__(vocab, **kwargs)
    self._text_field_embedder = text_field_embedder

    if contextualizer.is_bidirectional() is not bidirectional:
        raise ConfigurationError(
            "Bidirectionality of contextualizer must match bidirectionality of "
            "language model. "
            f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
            f"language model bidirectional: {bidirectional}")

    self._contextualizer = contextualizer
    self._bidirectional = bidirectional

    # The dimension for making predictions just in the forward
    # (or backward) direction.
    if self._bidirectional:
        self._forward_dim = contextualizer.get_output_dim() // 2
    else:
        self._forward_dim = contextualizer.get_output_dim()

    if num_samples is not None:
        self._softmax_loss = SampledSoftmaxLoss(
            num_words=vocab.get_vocab_size("transactions"),
            embedding_dim=self._forward_dim,
            num_samples=num_samples,
            sparse=sparse_embeddings,
        )
    else:
        self._softmax_loss = SoftmaxLoss(
            num_words=vocab.get_vocab_size("transactions"),
            embedding_dim=self._forward_dim,
        )

    # This buffer is now unused and exists only for backwards compatibility reasons.
    self.register_buffer("_last_average_loss", torch.zeros(1))

    self._perplexity = Perplexity()

    if dropout:
        self._dropout = torch.nn.Dropout(dropout)
    else:
        self._dropout = lambda x: x

    if initializer is not None:
        initializer(self)

def test_sampled_equals_unsampled_during_eval(self):
    sampled_softmax = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=40)
    unsampled_softmax = _SoftmaxLoss(num_words=10000, embedding_dim=12)

    sampled_softmax.eval()
    unsampled_softmax.eval()

    # set weights equal, use transpose because opposite shapes
    sampled_softmax.softmax_w.data = unsampled_softmax.softmax_w.t()
    sampled_softmax.softmax_b.data = unsampled_softmax.softmax_b

    # sequence_length, embedding_dim
    embedding = torch.rand(100, 12)
    targets = torch.randint(0, 1000, (100,)).long()

    full_loss = unsampled_softmax(embedding, targets).item()
    sampled_loss = sampled_softmax(embedding, targets).item()

    # Should be really close: in eval mode SampledSoftmaxLoss computes the full
    # (unsampled) softmax, so with identical weights the two losses match.
    np.testing.assert_almost_equal(sampled_loss, full_loss)

def test_sampled_equals_unsampled_when_biased_against_non_sampled_positions(self):
    sampled_softmax = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=10)
    unsampled_softmax = SoftmaxLoss(num_words=10000, embedding_dim=12)

    # fake out choice function
    FAKE_SAMPLES = [100, 200, 300, 400, 500, 600, 700, 800, 900, 9999]

    def fake_choice(num_words: int, num_samples: int) -> Tuple[np.ndarray, int]:
        assert (num_words, num_samples) == (10000, 10)
        return np.array(FAKE_SAMPLES), 12

    sampled_softmax.choice_func = fake_choice

    # bias out the unsampled terms:
    for i in range(10000):
        if i not in FAKE_SAMPLES:
            unsampled_softmax.softmax_b[i] = -10000

    # set weights equal, use transpose because opposite shapes
    sampled_softmax.softmax_w.data = unsampled_softmax.softmax_w.t()
    sampled_softmax.softmax_b.data = unsampled_softmax.softmax_b

    sampled_softmax.train()
    unsampled_softmax.train()

    # sequence_length, embedding_dim
    embedding = torch.rand(100, 12)
    targets = torch.randint(0, 1000, (100,)).long()

    full_loss = unsampled_softmax(embedding, targets).item()
    sampled_loss = sampled_softmax(embedding, targets).item()

    # Should be close
    pct_error = (sampled_loss - full_loss) / full_loss
    assert abs(pct_error) < 0.001

def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             contextualizer: Seq2SeqEncoder,
             dropout: float = None,
             loss_scale: Union[float, str] = 1.0,
             num_samples: int = None,
             sparse_embeddings: bool = False,
             bidirectional: bool = False,
             initializer: InitializerApplicator = None) -> None:
    super().__init__(vocab)
    self._text_field_embedder = text_field_embedder

    if contextualizer.is_bidirectional() is not bidirectional:
        raise ConfigurationError(
            "Bidirectionality of contextualizer must match bidirectionality of "
            "language model. "
            f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
            f"language model bidirectional: {bidirectional}")

    self._contextualizer = contextualizer
    self._bidirectional = bidirectional

    # The dimension for making predictions just in the forward
    # (or backward) direction.
    if self._bidirectional:
        self._forward_dim = contextualizer.get_output_dim() // 2
    else:
        self._forward_dim = contextualizer.get_output_dim()

    # TODO(joelgrus): more sampled softmax configuration options, as needed.
    if num_samples is not None:
        self._softmax_loss = SampledSoftmaxLoss(num_words=vocab.get_vocab_size(),
                                                embedding_dim=self._forward_dim,
                                                num_samples=num_samples,
                                                sparse=sparse_embeddings)
    else:
        self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                          embedding_dim=self._forward_dim)

    # TODO(brendanr): Output perplexity here. e^loss
    self.register_buffer('_last_average_loss', torch.zeros(1))

    if dropout:
        self._dropout = torch.nn.Dropout(dropout)
    else:
        self._dropout = lambda x: x

    self._loss_scale = loss_scale

    if initializer is not None:
        initializer(self)

def test_sampled_almost_equals_unsampled_when_num_samples_is_almost_all(self):
    sampled_softmax = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=9999)
    unsampled_softmax = _SoftmaxLoss(num_words=10000, embedding_dim=12)

    # sequence_length, embedding_dim
    embedding = torch.rand(100, 12)
    targets = torch.randint(0, 1000, (100,)).long()

    full_loss = unsampled_softmax(embedding, targets).item()
    sampled_loss = sampled_softmax(embedding, targets).item()

    # Should be really close
    pct_error = (sampled_loss - full_loss) / full_loss
    assert abs(pct_error) < 0.02

def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             contextualizer: Seq2SeqEncoder,
             forward_segmental_contextualizer: Seq2SeqEncoder,
             backward_segmental_contextualizer: Seq2SeqEncoder,
             label_feature_dim: int,
             softmax_projection_dim: int,
             label_namespace: str = "labels",
             dropout: float = None,
             num_samples: int = None,
             sparse_embeddings: bool = False,
             bidirectional: bool = True,
             initializer: InitializerApplicator = None) -> None:
    super().__init__(vocab=vocab,
                     text_field_embedder=text_field_embedder,
                     contextualizer=contextualizer,
                     dropout=dropout,
                     num_samples=num_samples,
                     sparse_embeddings=sparse_embeddings,
                     bidirectional=bidirectional,
                     initializer=initializer)

    self._forward_segmental_contextualizer = forward_segmental_contextualizer
    self._backward_segmental_contextualizer = backward_segmental_contextualizer

    if num_samples is not None:
        self._softmax_loss = SampledSoftmaxLoss(num_words=vocab.get_vocab_size(),
                                                embedding_dim=softmax_projection_dim,
                                                num_samples=num_samples,
                                                sparse=sparse_embeddings)
    else:
        self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                          embedding_dim=softmax_projection_dim)

    self.num_classes = self.vocab.get_vocab_size(label_namespace)
    self.label_feature_embedding = Embedding(self.num_classes, label_feature_dim)

    base_dim = contextualizer.get_output_dim() // 2
    seg_dim = base_dim + label_feature_dim
    self._forward_dim = softmax_projection_dim

    self.pre_segmental_layer = TimeDistributed(Linear(seg_dim, softmax_projection_dim))
    self.projection_layer = TimeDistributed(Linear(base_dim * 2, softmax_projection_dim))

def __init__(
    self,
    vocab: Vocabulary,
    text_field_embedder: TextFieldEmbedder,
    contextualizer: Seq2SeqEncoder,
    hyperbolic_embedder: TextFieldEmbedder,
    hyperbolic_encoder: Seq2VecEncoder,
    hyperbolic_weight: float,
    is_baseline: bool = False,
    dropout: float = None,
    num_samples: int = None,
    sparse_embeddings: bool = False,
    bidirectional: bool = False,
    initializer: InitializerApplicator = None,
) -> None:
    super().__init__(
        vocab, text_field_embedder, contextualizer, dropout,
        num_samples, sparse_embeddings, bidirectional, initializer
    )

    # re-initialize self._softmax_loss so it uses the 'euclidean' namespace
    # instead of the default token namespace
    if num_samples is not None:
        self._softmax_loss = SampledSoftmaxLoss(
            num_words=vocab.get_vocab_size(namespace='euclidean'),
            embedding_dim=self._forward_dim,
            num_samples=num_samples,
            sparse=sparse_embeddings,
        )
    else:
        self._softmax_loss = SoftmaxLoss(
            num_words=vocab.get_vocab_size(namespace='euclidean'),
            embedding_dim=self._forward_dim,
        )

    # initialize hyperbolic components
    self._hyperbolic_embedder = hyperbolic_embedder
    self._hyperbolic_encoder = hyperbolic_encoder
    self._hyperbolic_encoding_loss = HyperbolicL1()
    self._hyperbolic_weight = hyperbolic_weight

    # vanilla language model baseline
    self.is_baseline = is_baseline

def test_sampled_softmax_has_smaller_loss_in_train_mode(self):
    sampled_softmax = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=10)

    # sequence_length, embedding_dim
    embedding = torch.rand(100, 12)
    targets = torch.randint(0, 1000, (100,)).long()

    sampled_softmax.train()
    train_loss = sampled_softmax(embedding, targets).item()

    sampled_softmax.eval()
    eval_loss = sampled_softmax(embedding, targets).item()

    # In train mode only the sampled words enter the normalization, so the
    # sampled loss underestimates the full softmax loss computed in eval mode.
    assert eval_loss > train_loss

def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             contextualizer: Seq2SeqEncoder,
             layer_norm: Optional[MaskedLayerNorm] = None,
             dropout: float = None,
             loss_scale: Union[float, str] = 1.0,
             remove_bos_eos: bool = True,
             num_samples: int = None,
             sparse_embeddings: bool = False) -> None:
    super().__init__(vocab)
    self._text_field_embedder = text_field_embedder
    self._layer_norm = layer_norm or (lambda x: x)

    if not contextualizer.is_bidirectional():
        raise ConfigurationError("contextualizer must be bidirectional")

    self._contextualizer = contextualizer
    # The dimension for making predictions just in the forward
    # (or backward) direction.
    self._forward_dim = contextualizer.get_output_dim() // 2

    # TODO(joelgrus): more sampled softmax configuration options, as needed.
    if num_samples is not None:
        self._softmax_loss = SampledSoftmaxLoss(num_words=vocab.get_vocab_size(),
                                                embedding_dim=self._forward_dim,
                                                num_samples=num_samples,
                                                sparse=sparse_embeddings)
    else:
        self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                          embedding_dim=self._forward_dim)

    self.register_buffer('_last_average_loss', torch.zeros(1))

    if dropout:
        self._dropout = torch.nn.Dropout(dropout)
    else:
        self._dropout = lambda x: x

    self._loss_scale = loss_scale
    self._remove_bos_eos = remove_bos_eos

def __init__(self,
             vocab: Vocabulary,
             embedder: TextFieldEmbedder,
             contextualizer: Seq2SeqEncoder,
             dropout: float = None,
             tie_embeddings: bool = True,
             num_samples: int = None,
             use_variational_dropout: bool = False):
    super().__init__(vocab)
    self._embedder = embedder
    self._contextualizer = contextualizer
    self._context_dim = contextualizer.get_output_dim()

    if use_variational_dropout:
        self._dropout = InputVariationalDropout(dropout) if dropout else lambda x: x
    else:
        self._dropout = Dropout(dropout) if dropout else lambda x: x

    vocab_size = self.vocab.get_vocab_size()
    padding_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN)
    if num_samples:
        self._softmax_loss = SampledSoftmaxLoss(vocab_size, self._context_dim, num_samples)
    else:
        self._softmax_loss = SoftmaxLoss(vocab_size, self._context_dim, padding_index)

    self._tie_embeddings = tie_embeddings
    if self._tie_embeddings:
        embedder_children = dict(self._embedder.named_children())
        word_embedder = embedder_children["token_embedder_tokens"]
        assert self._softmax_loss.softmax_w.size() == word_embedder.weight.size()
        self._softmax_loss.softmax_w = word_embedder.weight

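# Illustrative only (hypothetical sizes): the weight tying done above points the
# softmax output matrix at the input embedding matrix, so a single Parameter is
# trained and any update is visible through both names. This is a standalone
# sketch of the idea, not part of the model code.
import torch

embedding = torch.nn.Embedding(num_embeddings=100, embedding_dim=16)
softmax_w = embedding.weight                  # shared Parameter, shape (100, 16)

with torch.no_grad():
    softmax_w[0, 0] = 42.0                    # update through the "softmax" view

assert embedding.weight[0, 0].item() == 42.0  # visible through the embedding too
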
def __init__(
    self,
    vocab: Vocabulary,
    text_field_embedder: TextFieldEmbedder,
    contextualizer: Seq2SeqEncoder,
    aux_contextualizer: Seq2SeqEncoder,
    dropout: float = None,
    num_samples: int = None,
    sparse_embeddings: bool = False,
    bidirectional: bool = False,
    initializer: InitializerApplicator = None,
    regularizer: Optional[RegularizerApplicator] = None,
) -> None:
    super().__init__(vocab, regularizer)
    self._text_field_embedder = text_field_embedder

    if contextualizer.is_bidirectional() is not bidirectional:
        raise ConfigurationError(
            "Bidirectionality of contextualizer must match bidirectionality of "
            "language model. "
            f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
            f"language model bidirectional: {bidirectional}")

    self._contextualizer_lang1 = aux_contextualizer
    self._contextualizer_lang2 = copy.deepcopy(aux_contextualizer)
    self._contextualizer = contextualizer
    self._bidirectional = bidirectional
    self._bidirectional_aux = aux_contextualizer.is_bidirectional()

    # The dimension for making predictions just in the forward
    # (or backward) direction.
    # main contextualizer forward dim
    if self._bidirectional:
        self._forward_dim = contextualizer.get_output_dim() // 2
    else:
        self._forward_dim = contextualizer.get_output_dim()

    # aux contextualizer forward dim
    if self._bidirectional_aux:
        self._forward_dim_aux = aux_contextualizer.get_output_dim() // 2
    else:
        self._forward_dim_aux = aux_contextualizer.get_output_dim()

    # TODO(joelgrus): more sampled softmax configuration options, as needed.
    if num_samples is not None:
        self._lang1_softmax_loss = SampledSoftmaxLoss(
            num_words=vocab.get_vocab_size(),
            embedding_dim=self._forward_dim_aux,
            num_samples=num_samples,
            sparse=sparse_embeddings,
        )
        self._lang2_softmax_loss = SampledSoftmaxLoss(
            num_words=vocab.get_vocab_size(),
            embedding_dim=self._forward_dim_aux,
            num_samples=num_samples,
            sparse=sparse_embeddings,
        )
        self._cm_softmax_loss = SampledSoftmaxLoss(
            num_words=vocab.get_vocab_size(),
            embedding_dim=self._forward_dim,
            num_samples=num_samples,
            sparse=sparse_embeddings,
        )
    else:
        self._lang1_softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                                embedding_dim=self._forward_dim_aux)
        self._lang2_softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                                embedding_dim=self._forward_dim_aux)
        self._cm_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                     embedding_dim=self._forward_dim)

    # This buffer is now unused and exists only for backwards compatibility reasons.
    self.register_buffer("_last_average_loss", torch.zeros(1))

    self._lang1_perplexity = Perplexity()
    self._lang2_perplexity = Perplexity()
    self._cm_perplexity = Perplexity()

    if dropout:
        self._dropout = torch.nn.Dropout(dropout)
    else:
        self._dropout = lambda x: x

    if initializer is not None:
        initializer(self)

def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             encoder: Seq2SeqEncoder,
             spellchecker_namespace: str = 'target_tokens',
             punct_namespace: str = 'punct_labels',
             feedforward: Optional[FeedForward] = None,
             punct_hidden: int = 256,
             embedding_dropout: Optional[float] = None,
             encoded_dropout: Optional[float] = None,
             punct_dropout: Optional[float] = None,
             punct_weight: Optional[Dict[str, float]] = None,
             num_samples: Optional[int] = None,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None) -> None:
    super().__init__(vocab, regularizer)
    self.label_namespace = punct_namespace
    self.text_field_embedder = text_field_embedder
    self.token_vocab_size = self.vocab.get_vocab_size(spellchecker_namespace)
    self.punct_vocab_size = self.vocab.get_vocab_size(punct_namespace)
    self.encoder = encoder
    self.embedding_dropout = Dropout(embedding_dropout) if embedding_dropout is not None else None
    self.encoded_dropout = Dropout(encoded_dropout) if encoded_dropout is not None else None
    self.feedforward = feedforward

    if feedforward is not None:
        self.output_dim = feedforward.get_output_dim()
    else:
        self.output_dim = self.encoder.get_output_dim()

    if punct_dropout is not None:
        self.punct_projection = Sequential(Linear(self.output_dim, punct_hidden),
                                           Dropout(punct_dropout),
                                           Linear(punct_hidden, self.punct_vocab_size))
    else:
        self.punct_projection = Sequential(Linear(self.output_dim, punct_hidden),
                                           Linear(punct_hidden, self.punct_vocab_size))

    self.losses = {
        'spellchecker': SoftmaxLoss(num_words=self.token_vocab_size,
                                    embedding_dim=self.output_dim + 1)
        if num_samples is None else
        SampledSoftmaxLoss(num_words=self.token_vocab_size,
                           embedding_dim=self.output_dim + 1,
                           num_samples=num_samples),
        'punct': CrossEntropyLoss(weight=self.__get_weight_tensor(punct_weight),
                                  reduction='sum',
                                  ignore_index=-1)
    }
    self.add_module('spellchecker_loss', self.losses['spellchecker'])
    self.add_module('punct_loss', self.losses['punct'])

    self.metrics = {'punct_accuracy': CategoricalAccuracy()}
    self.metrics.update({
        f'f1_score_{name}': F1Measure(self.vocab.get_token_index(name, namespace=punct_namespace))
        for name in self.vocab.get_token_to_index_vocabulary(namespace=punct_namespace)
    })

    initializer(self)

class AclSequenceModel(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 seq_embedder: TextFieldEmbedder,
                 abstract_text_field_embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder,
                 calculate_recall: bool = False,
                 use_abstracts: bool = True,
                 use_node_vectors: bool = True,
                 num_samples: int = None,
                 dropout: float = None) -> None:
        super().__init__(vocab)
        self._abstract_text_field_embedder = abstract_text_field_embedder
        self._use_abstracts = use_abstracts
        self._use_node_vectors = use_node_vectors
        self._seq_embedder = seq_embedder
        self._calculate_recall = calculate_recall

        # lstm encoder uses PytorchSeq2SeqWrapper for pytorch lstm
        self._contextualizer = contextualizer
        self._forward_dim = contextualizer.get_output_dim()

        if num_samples is not None:
            self._softmax_loss = SampledSoftmaxLoss(num_words=vocab.get_vocab_size(),
                                                    embedding_dim=self._forward_dim,
                                                    num_samples=num_samples,
                                                    sparse=False)
        else:
            self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                              embedding_dim=self._forward_dim)

        self._n_list = range(1, 50)
        self._recall_at_n = {}
        for n in self._n_list:
            self._recall_at_n[n] = RecallAtN(n)

        self._perplexity = Perplexity()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

    def _compute_loss(self,
                      lm_embeddings: torch.Tensor,
                      targets: torch.Tensor) -> torch.Tensor:
        # Because the targets are offset by 1, we re-mask to
        # remove the final 0 in the targets
        mask = targets > 0
        non_masked_targets = targets.masked_select(mask) - 1
        non_masked_embeddings = lm_embeddings.masked_select(
            mask.unsqueeze(-1)).view(-1, self._forward_dim)
        return self._softmax_loss(non_masked_embeddings, non_masked_targets)

    def _compute_probs(self,
                       lm_embeddings: torch.Tensor,
                       targets: torch.Tensor) -> torch.Tensor:
        # Because the targets are offset by 1, we re-mask to
        # remove the final 0 in the targets
        mask = targets > 0
        non_masked_targets = targets.masked_select(mask) - 1
        non_masked_embeddings = lm_embeddings.masked_select(
            mask.unsqueeze(-1)).view(-1, self._forward_dim)
        return self._softmax_loss.probs(non_masked_embeddings)

    def num_layers(self) -> int:
        return self._contextualizer.num_layers + 1

    def forward(self,
                abstracts: Dict[str, torch.LongTensor],
                paper_ids: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
        """
        Computes the loss from the batch.
        """
        if self._use_abstracts and self._use_node_vectors:
            embeddings = torch.cat([
                self._abstract_text_field_embedder(abstracts)[:, :, 0, :],
                self._seq_embedder(paper_ids)
            ], dim=-1)
            mask = get_text_field_mask(abstracts, num_wrapping_dims=1)
            mask = mask.sum(dim=-1) > 0
        elif self._use_abstracts:
            embeddings = self._abstract_text_field_embedder(abstracts)[:, :, 0, :]
            mask = get_text_field_mask(abstracts, num_wrapping_dims=1)
            mask = mask.sum(dim=-1) > 0
        elif self._use_node_vectors:
            embeddings = self._seq_embedder(paper_ids)
            mask = get_text_field_mask(paper_ids)
        else:
            # When use_node_vectors is false, the embedder should be configured
            # to initialize random embeddings. The redundant else condition
            # makes this difference in behavior a little more explicit, even
            # though the content of the block is identical.
            embeddings = self._seq_embedder(paper_ids)
            mask = get_text_field_mask(paper_ids)

        contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = \
            self._contextualizer(embeddings, mask.long())
        contextual_embeddings_with_dropout = self._dropout(contextual_embeddings)

        return_dict = {}
        assert isinstance(contextual_embeddings_with_dropout, torch.Tensor)

        # targets is like paper ids, but offset forward by 1 in the second
        # dimension.
        targets = torch.zeros_like(paper_ids['tokens'])
        targets[:, 0:targets.size()[1] - 1] = paper_ids['tokens'][:, 1:]
        loss = self._compute_loss(contextual_embeddings_with_dropout, targets)

        num_targets = torch.sum((targets > 0).long())
        if num_targets > 0:
            average_loss = loss / num_targets.float()
        else:
            average_loss = torch.tensor(0.0).to(targets.device)

        perplexity = self._perplexity(average_loss)

        if self._calculate_recall:
            top_k = self.get_recall_at_n(contextual_embeddings, targets)
            return_dict.update({'top_k': top_k})

        if num_targets > 0:
            return_dict.update({
                'loss': average_loss,
                'batch_weight': num_targets.float()
            })
        else:
            return_dict.update({'loss': average_loss})

        return_dict.update({
            'lm_embeddings': contextual_embeddings,
            'lm_targets': targets,
            'noncontextual_embeddings': embeddings,
        })

        return return_dict

    def get_recall_at_n(self, embeddings, targets):
        top_n = []
        # iterate over batches:
        for embeddings, targets in zip(embeddings.detach(), targets.detach()):
            # (sequence_length, #targets)
            probs = self._compute_probs(embeddings, targets)
            top_probs, top_indices = probs.topk(k=max(self._n_list), dim=-1)
            top_ids = [[self.vocab.get_token_from_index(int(i)) for i in top_n]
                       for top_n in top_indices]
            top_n.append(top_ids)
            mask = targets > 0
            non_masked_targets = targets.masked_select(mask) - 1
            for n in self._n_list:
                self._recall_at_n[n](non_masked_targets, top_indices)
        return top_n

    def get_metrics(self, reset: bool = False):
        metrics = {"perplexity": self._perplexity.get_metric(reset=reset)}
        if self._calculate_recall:
            for n in self._n_list:
                recall = self._recall_at_n[n].get_metric(reset=reset)
                metrics.update({"recall_at_{}".format(n): recall})
        return metrics
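
# Illustrative only: how the next-step prediction targets are built in forward() above.
# Each position's target is the *next* paper id; the final slot is left as 0 (padding)
# so the `targets > 0` mask drops it. The token ids below are hypothetical.
import torch

paper_id_tokens = torch.tensor([[5, 9, 3, 0]])               # batch_size=1, seq_len=4, 0 = padding
targets = torch.zeros_like(paper_id_tokens)
targets[:, 0:targets.size()[1] - 1] = paper_id_tokens[:, 1:]
# targets == tensor([[9, 3, 0, 0]])
mask = targets > 0                                            # two real prediction positions
assert mask.sum().item() == 2
# _compute_loss() then selects the masked positions and subtracts 1 from the ids,
# presumably to map them into a softmax vocabulary that excludes the padding index.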