def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates the Brier score of one batch of predictions.

    Args:
      y_true: Class labels, as a tensor, NumPy array or other value accepted
        by `tf.convert_to_tensor`. A tensor of rank > 1 (e.g. shape `(N, 1)`)
        is reshaped to its leading dimension `(N,)`; labels are then cast to
        `tf.int32`.
      y_pred: Predicted class probabilities, as a tensor, NumPy array or other
        value accepted by `tf.convert_to_tensor`.
      sample_weight: Ignored; it is not forwarded to the parent
        `update_state`.
    """
    # convert_to_tensor is a no-op for tensors and, unlike an exact
    # `type(x) == np.ndarray` check, also handles ndarray subclasses and lists.
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.convert_to_tensor(y_pred)

    # Use the *static* rank: `len(tf.shape(x))` raises TypeError on symbolic
    # tensors, which is what this method receives inside a tf.function.
    # When the rank is unknown, reshaping is still safe: for rank 0 or 1 the
    # target shape `tf.shape(y_true)[:1]` leaves the tensor unchanged.
    rank = y_true.shape.rank
    if rank is None or rank > 1:
        # Collapses e.g. (N, 1) labels to (N,). As before, this fails loudly
        # if the trailing dimensions hold more than one element per row.
        y_true = tf.reshape(y_true, tf.shape(y_true)[:1])
    y_true = tf.cast(y_true, dtype=tf.int32)

    # NOTE(review): the +1 presumably restores the constant term of
    # sum_i (p_i - y_i)^2 that um.brier_score leaves out — confirm against
    # the um implementation in use.
    brier_score = 1 + um.brier_score(labels=y_true, probabilities=y_pred)
    super(BrierScore, self).update_state(brier_score)
def test_brier_decomposition(self, temperature, nlabels, nsamples):
    """Checks uncertainty - resolution + reliability against the direct score.

    Random logits (scaled by 1/temperature) and random integer labels are drawn
    under a fixed seed. The three terms from `um.brier_decomposition` are then
    recombined and compared with `um.brier_score` on the same inputs, within a
    tolerance of 1e-2. Resolution and reliability must also be non-negative.
    """
    tf.random.set_seed(1)
    scaled_logits = tf.random.normal((nsamples, nlabels)) / temperature
    class_labels = tf.random.uniform(
        (nsamples,), maxval=nlabels, dtype=tf.int32)

    uncertainty, resolution, reliability = (
        float(term)
        for term in um.brier_decomposition(
            labels=class_labels, logits=scaled_logits))

    # Recombine the decomposition into an estimate of the Brier score.
    recombined = uncertainty - resolution + reliability
    # Reference value computed directly from the same labels and logits.
    direct = float(um.brier_score(labels=class_labels, logits=scaled_logits))

    logging.info(
        "Brier, n=%d k=%d T=%.2f, Unc %.4f - Res %.4f + Rel %.4f = "
        "Brier %.4f, Brier-direct %.4f", nsamples, nlabels, temperature,
        uncertainty, resolution, reliability, recombined, direct)

    for term, message in ((resolution, "Brier resolution negative"),
                          (reliability, "Brier reliability negative")):
        self.assertGreaterEqual(term, 0.0, msg=message)

    mismatch_msg = (
        "Brier from decomposition (%.4f) and Brier direct (%.4f) disagree "
        "beyond estimation error." % (recombined, direct))
    self.assertAlmostEqual(recombined, direct, delta=1.0e-2, msg=mismatch_msg)
def update_state(self, y_true, y_pred, sample_weight=None):
    """Feeds the Brier score of one batch into the parent metric's state.

    Args:
      y_true: Labels, passed unchanged to `um.brier_score` as `labels`.
      y_pred: Predicted probabilities, passed unchanged to `um.brier_score`
        as `probabilities`.
      sample_weight: Unused; it is not forwarded to the parent
        `update_state` (presumably accepted only for Keras `Metric` API
        compatibility).
    """
    del sample_weight  # Intentionally ignored; see docstring.
    batch_score = um.brier_score(labels=y_true, probabilities=y_pred)
    parent = super(BrierScore, self)
    parent.update_state(batch_score)