def test_sampled_almost_equals_unsampled_when_num_samples_is_almost_all(self):
    """When nearly every word is sampled, the sampled loss should closely
    approximate the exact (unsampled) softmax loss."""
    approx_softmax = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=9999)
    exact_softmax = _SoftmaxLoss(num_words=10000, embedding_dim=12)

    # shape: (sequence_length, embedding_dim)
    inputs = torch.rand(100, 12)
    word_targets = torch.randint(0, 1000, (100,)).long()

    exact_value = exact_softmax(inputs, word_targets).item()
    approx_value = approx_softmax(inputs, word_targets).item()

    # With 9999 of 10000 words sampled, the relative error should be tiny.
    relative_error = (approx_value - exact_value) / exact_value
    assert abs(relative_error) < 0.02
def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             contextualizer: Seq2SeqEncoder,
             forward_segmental_contextualizer: Seq2SeqEncoder,
             backward_segmental_contextualizer: Seq2SeqEncoder,
             label_feature_dim: int,
             softmax_projection_dim: int,
             label_namespace: str = "labels",
             dropout: float = None,
             num_samples: int = None,
             sparse_embeddings: bool = False,
             bidirectional: bool = True,
             initializer: InitializerApplicator = None) -> None:
    """Build a segmental language model on top of the base LM.

    Adds forward/backward segmental contextualizers, a label-feature
    embedding, and projection layers feeding a (possibly sampled)
    softmax loss over the full vocabulary.

    Parameters
    ----------
    label_feature_dim : int
        Size of the embedding used for segment labels.
    softmax_projection_dim : int
        Dimension that both projection layers map into; also the
        embedding dim of the softmax loss.
    num_samples : int, optional
        If given, use a sampled softmax with this many samples instead
        of the exact softmax (useful for large vocabularies).
    """
    # The base-class constructor handles the shared LM components
    # (embedder, main contextualizer, dropout, etc.).
    super().__init__(vocab=vocab,
                     text_field_embedder=text_field_embedder,
                     contextualizer=contextualizer,
                     dropout=dropout,
                     num_samples=num_samples,
                     sparse_embeddings=sparse_embeddings,
                     bidirectional=bidirectional,
                     initializer=initializer)
    self._forward_segmental_contextualizer = forward_segmental_contextualizer
    self._backward_segmental_contextualizer = backward_segmental_contextualizer
    # Choose sampled vs. exact softmax over the token vocabulary.
    # NOTE(review): this overwrites the _softmax_loss the super().__init__
    # presumably created, but with embedding_dim=softmax_projection_dim
    # instead of the contextualizer output dim — confirm intended.
    if num_samples is not None:
        self._softmax_loss = SampledSoftmaxLoss(
            num_words=vocab.get_vocab_size(),
            embedding_dim=softmax_projection_dim,
            num_samples=num_samples,
            sparse=sparse_embeddings)
    else:
        self._softmax_loss = _SoftmaxLoss(
            num_words=vocab.get_vocab_size(),
            embedding_dim=softmax_projection_dim)
    # Number of segment labels (distinct from the token vocabulary).
    self.num_classes = self.vocab.get_vocab_size(label_namespace)
    self.label_feature_embedding = Embedding(self.num_classes, label_feature_dim)
    # Halving assumes the contextualizer output concatenates forward and
    # backward halves (bidirectional) — TODO confirm for unidirectional use.
    base_dim = contextualizer.get_output_dim() // 2
    seg_dim = base_dim + label_feature_dim
    self._forward_dim = softmax_projection_dim
    # Projects [segment context; label feature] into the softmax space.
    self.pre_segmental_layer = TimeDistributed(
        Linear(seg_dim, softmax_projection_dim))
    # Projects the full bidirectional context into the softmax space.
    self.projection_layer = TimeDistributed(
        Linear(base_dim * 2, softmax_projection_dim))
def test_sampled_equals_unsampled_during_eval(self):
    """In eval mode the sampled softmax computes the exact loss, so with
    tied weights it must match the unsampled softmax almost exactly."""
    approx_softmax = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=40)
    exact_softmax = _SoftmaxLoss(num_words=10000, embedding_dim=12)

    approx_softmax.eval()
    exact_softmax.eval()

    # Tie the parameters; the transpose accounts for the opposite
    # weight-matrix layouts of the two modules.
    approx_softmax.softmax_w.data = exact_softmax.softmax_w.t()
    approx_softmax.softmax_b.data = exact_softmax.softmax_b

    # shape: (sequence_length, embedding_dim)
    inputs = torch.rand(100, 12)
    word_targets = torch.randint(0, 1000, (100,)).long()

    exact_value = exact_softmax(inputs, word_targets).item()
    approx_value = approx_softmax(inputs, word_targets).item()

    np.testing.assert_almost_equal(approx_value, exact_value)
def test_sampled_equals_unsampled_when_biased_against_non_sampled_positions(
        self):
    """If every non-sampled word is biased to ~zero probability, the sampled
    loss (over a fixed fake sample set) should match the full loss."""
    approx_softmax = SampledSoftmaxLoss(num_words=10000, embedding_dim=12, num_samples=10)
    exact_softmax = _SoftmaxLoss(num_words=10000, embedding_dim=12)

    # Replace the sampler with a deterministic one returning fixed ids.
    FAKE_SAMPLES = [100, 200, 300, 400, 500, 600, 700, 800, 900, 9999]

    def fake_choice(num_words: int, num_samples: int) -> Tuple[np.ndarray, int]:
        assert (num_words, num_samples) == (10000, 10)
        return np.array(FAKE_SAMPLES), 12

    approx_softmax.choice_func = fake_choice

    # Push the bias of every word outside the fake sample set to -10000 so
    # those words contribute ~nothing to the full softmax normalizer.
    fake_sample_set = set(FAKE_SAMPLES)
    for word_id in range(10000):
        if word_id not in fake_sample_set:
            exact_softmax.softmax_b[word_id] = -10000

    # Tie the parameters; transpose accounts for opposite layouts.
    approx_softmax.softmax_w.data = exact_softmax.softmax_w.t()
    approx_softmax.softmax_b.data = exact_softmax.softmax_b

    approx_softmax.train()
    exact_softmax.train()

    # shape: (sequence_length, embedding_dim)
    inputs = torch.rand(100, 12)
    word_targets = torch.randint(0, 1000, (100, )).long()

    exact_value = exact_softmax(inputs, word_targets).item()
    approx_value = approx_softmax(inputs, word_targets).item()

    relative_error = (approx_value - exact_value) / exact_value
    assert abs(relative_error) < 0.001