def test_expected_data_uncertainty_with_logits(self):
     """Checks the data-uncertainty output of `um.model_uncertainty`.

     `um.model_uncertainty` is unpacked into exactly three values. Only the
     third is compared against the `self.expected_data_uncertainty` fixture;
     the first two are ignored here.
     """
     _, _, data_uncertainty = um.model_uncertainty(
         logits=self.logits)
     self.assertTensorsAlmostEqual(
         self.expected_data_uncertainty,
         data_uncertainty,
         msg="Data uncertainty not almost equal what is expected.")
 def test_wrong_logits_shape(self):
     """`um.model_uncertainty` must raise ValueError for mis-shaped logits.

     Averaging `self.logits` over its last axis (keepdims defaults to False)
     lowers the tensor's rank by one. Passing that tensor is expected to be
     rejected with a `ValueError`.
     """
     # NOTE(review): presumably model_uncertainty requires the full-rank
     # logits layout of `self.logits` (rank 3?) -- confirm in its source.
     collapsed_logits = tf.reduce_mean(self.logits, axis=-1)
     self.assertRaises(ValueError, um.model_uncertainty,
                       logits=collapsed_logits)
 def test_total_uncertainty_with_logits(self):
     """Checks the total-uncertainty output of `um.model_uncertainty`.

     `um.model_uncertainty` is unpacked into exactly three values. The
     second one is compared against the `self.total_uncertainty` fixture.
     """
     # NOTE(review): the fixture is named `total_uncertainty` rather than
     # `expected_total_uncertainty` (cf. `expected_data_uncertainty`);
     # confirm it holds the expected value set up for this test.
     _, total_uncertainty, _ = um.model_uncertainty(
         logits=self.logits)
     failure_msg = "Total uncertainty not almost equal what is expected."
     self.assertTensorsAlmostEqual(
         self.total_uncertainty, total_uncertainty, msg=failure_msg)