Example #1
0
 def test_constant_shift(self):
     """Check that the loss ignores a per-example constant shift of the logits.

     Each row of ``activations`` is offset by its own random scalar (``bias``
     has shape [3, 1] and broadcasts across the last axis). The shifted and
     unshifted losses must agree for t1=0.5 and each t2 tried below.
     """
     labels = jnp.array([[0.2, 0.3, 0.5], [0.4, 0.4, 0.2], [0.7, 0.2, 0.1]])
     base_key = random.PRNGKey(seed=1335)
     # First half of the split seeds the bias, second half the activations.
     bias_key, act_key = random.split(base_key)
     activations = random.normal(act_key, shape=[3, 3])
     bias = random.normal(bias_key, shape=[3, 1])
     shifted = activations + bias
     t1 = 0.5
     for t2 in (0.8, 1.2):
         reference = loss.bi_tempered_logistic_loss(
             activations, labels, t1, t2)
         with_shift = loss.bi_tempered_logistic_loss(
             shifted, labels, t1, t2)
         npt.assert_allclose(reference, with_shift, atol=1e-6)
Example #2
0
 def test_loss_value(self):
   """Compare the loss against precomputed reference values.

   The same activations/labels are evaluated at two settings:
   (t1=0.5, t2=1.5) with the default ``num_iters``, and
   (t1=0.5, t2=0.8) with ``num_iters=20``.
   """
   labels = jnp.array([[0.2, 0.3, 0.5], [0.6, 0.3, 0.1], [0.2, 0.8, 0.0]])
   activations = jnp.array([[-0.5, 0.1, 2.0], [0.1, 1.5, -5.0],
                            [4.0, -3.0, -6.0]])
   # (positional temperatures, extra keyword args, expected per-row loss)
   cases = [
       ((0.5, 1.5), {}, [0.02301914, 0.18972909, 0.93874922]),
       ((0.5, 0.8), {'num_iters': 20},
        [0.21646356, 0.41836615, 1.33997854]),
   ]
   for temperatures, extra_kwargs, expected in cases:
     got = loss.bi_tempered_logistic_loss(activations, labels, *temperatures,
                                          **extra_kwargs)
     npt.assert_allclose(got, jnp.array(expected), atol=1e-4)
Example #3
0
 def test_limit_case_logistic_loss(self):
     """Check that t1 = t2 = 1.0 matches the plain cross-entropy loss.

     Random logits with one-hot (identity-matrix) labels are fed to both
     ``bi_tempered_logistic_loss`` and ``loss._cross_entropy_loss``. The
     results are compared at numpy's default ``assert_allclose`` tolerance.
     """
     one_hot = jnp.eye(3)
     logits = random.normal(random.PRNGKey(seed=1335), shape=[3, 3])
     expected = loss._cross_entropy_loss(logits=logits, labels=one_hot)
     got = loss.bi_tempered_logistic_loss(logits, one_hot, 1.0, 1.0)
     npt.assert_allclose(got, expected)
Example #4
0
 def test_label_smoothing(self):
   """Compare the loss with ``label_smoothing=0.1`` to precomputed values.

   One-hot (identity-matrix) labels are passed together with
   ``label_smoothing=0.1``, and the loss is evaluated at t1=0.5, t2=1.5.
   """
   one_hot = jnp.eye(3)
   logits = jnp.array([[-0.5, 0.1, 2.0],
                       [0.1, 1.5, -5.0],
                       [4.0, -3.0, -6.0]])
   expected = jnp.array([0.76652711, 0.08627685, 1.35443510])
   got = loss.bi_tempered_logistic_loss(logits, one_hot, 0.5, 1.5,
                                        label_smoothing=0.1)
   npt.assert_allclose(got, expected, atol=1e-5)
Example #5
0
 def test_dynamic_temperatures(self):
   """Check the loss as the temperatures (t1, t2) change between calls.

   Four (t1, t2) pairs are evaluated on the same single example with
   ``num_iters=5``. The results are compared against precomputed reference
   values.
   """
   labels = jnp.array([[0.2, 0.5, 0.3]])
   activations = jnp.array([[-0.5, 0.1, 2.0]])
   t1_values = [1.0, 0.9, 0.8, 0.7]
   t2_values = [1.0, 1.1, 1.2, 1.3]
   # One expected per-example loss (shape [1]) per temperature pair.
   expected = [[0.628705], [0.45677936], [0.34298314], [0.26295574]]
   actual = [
       loss.bi_tempered_logistic_loss(
           activations, labels, t1, t2, num_iters=5)
       for t1, t2 in zip(t1_values, t2_values)
   ]
   # assert_allclose takes (actual, desired). rtol is applied relative to
   # `desired`, and failure messages label the two arrays accordingly.
   npt.assert_allclose(actual, expected, atol=1e-5)