コード例 #1
0
def main():
    """Run the calibration demo end to end on synthetic 1-D data.

    Steps:
      1. Seed NumPy's global RNG and draw a synthetic dataset.
      2. Report a lower-bound estimate of the calibration error and the ECE
         of the raw (uncalibrated) confidences.
      3. Train a Platt-binning recalibrator (10 bins) on that dataset.
      4. Draw a second synthetic dataset, recalibrate its confidences, and
         report the scaling-binning L2 calibration error together with its
         confidence interval.
    """
    def as_percent(fraction):
        # Scale a metric by 100 for the percentage-style report lines.
        return 100 * fraction

    n = 1000
    np.random.seed(0)  # Keep results reproducible across runs.
    train_conf, train_labels = synthetic_data_1d(num_points=n)

    # Lower-bound estimate of the calibration error.
    # train_conf[i] is the uncalibrated model's confidence; train_labels[i]
    # is the true label.
    base_error = calibration.get_calibration_error(train_conf, train_labels)
    print("Uncalibrated model calibration error is > %.2f%%"
          % as_percent(base_error))

    base_ece = calibration.get_ece(train_conf, train_labels)
    print("Uncalibrated model ECE is > %.2f%%" % as_percent(base_ece))

    # Train a Platt-binning recalibrator on the first dataset.
    platt_binner = calibration.PlattBinnerCalibrator(n, num_bins=10)
    platt_binner.train_calibration(np.array(train_conf), train_labels)

    # Evaluate the recalibrated model on a second synthetic dataset.
    eval_conf, eval_labels = synthetic_data_1d(num_points=n)
    recalibrated = platt_binner.calibrate(eval_conf)
    recal_error = calibration.get_calibration_error(recalibrated, eval_labels)
    print("Scaling-binning L2 calibration error is %.2f%%"
          % as_percent(recal_error))

    # Confidence interval for the recalibrated calibration error.
    lower, _, upper = calibration.get_calibration_error_uncertainties(
        recalibrated, eval_labels)
    print("  Confidence interval is [%.2f%%, %.2f%%]"
          % (as_percent(lower), as_percent(upper)))
コード例 #2
0
 def on_epoch_end(self, data: Data) -> None:
     """Compute the calibration error over everything collected this epoch.

     Stacks the accumulated labels and predictions into arrays and stores
     them back on ``self.y_true`` / ``self.y_pred``. It then estimates the
     calibration error with ``cal.get_calibration_error``. When
     ``self.confidence_interval`` is not None, it also computes lower/upper
     bounds with ``cal.get_calibration_error_uncertainties``. The result is
     written to ``self.outputs[0]`` via ``data.write_with_log``. That result
     is a ``ValWithError(low, mid, high)`` when bounds were computed, and
     the bare point estimate otherwise. Every reported value is rounded to
     4 decimal places.

     Args:
         data: Object the result is written into via ``write_with_log``.
     """
     # NOTE(review): assumes self.y_true / self.y_pred hold the per-sample
     # values accumulated during the epoch (populated outside this view).
     # np.squeeze drops size-1 label axes. With a single collected sample it
     # would also drop the stacked sample axis and yield a 0-d array.
     # atleast_1d restores that axis so the labels keep one entry per sample.
     self.y_true = np.atleast_1d(np.squeeze(np.stack(self.y_true)))
     self.y_pred = np.stack(self.y_pred)
     mid = round(
         cal.get_calibration_error(probs=self.y_pred,
                                   labels=self.y_true,
                                   mode=self.method), 4)
     if self.confidence_interval is None:
         # No interval requested: report only the point estimate.
         data.write_with_log(self.outputs[0], mid)
         return
     low, _, high = cal.get_calibration_error_uncertainties(
         probs=self.y_pred,
         labels=self.y_true,
         mode=self.method,
         alpha=self.confidence_interval)
     data.write_with_log(self.outputs[0],
                         ValWithError(round(low, 4), mid, round(high, 4)))