def test_find_scaling_temperature(self):
  """Checks that the fitted temperature is close to the one used for sampling.

  Labels are drawn from a categorical distribution whose logits are the raw
  logits divided by a known target temperature. The temperature returned by
  `calibration_lib.find_scaling_temperature` must then have a relative error
  that is zero to one decimal place.
  """
  num_samples = 10**4
  num_classes = 5
  target_temperature = 100

  logits = tf.random.uniform([num_samples, num_classes], 0, 300)
  label_distribution = tfp.distributions.Categorical(
      logits=logits / target_temperature)
  labels = label_distribution.sample()

  temperature = calibration_lib.find_scaling_temperature(labels, logits)
  logging.info('temperature=%0.3f, target=%0.3f', temperature,
               target_temperature)

  rel_error = (temperature - target_temperature) / target_temperature
  self.assertAlmostEqual(0, rel_error, places=1)
def test_find_scaling_temperature_invalid_input(self):
  """Checks that mis-shaped labels or logits are rejected with ValueError.

  Three malformed (labels, logits) pairs are tried in turn: labels with an
  extra leading axis, logits with an extra leading axis, and transposed
  logits.
  """
  nsamples, nclasses = 10**4, 5
  logits = np.ones([nsamples, nclasses])
  labels = np.ones([nsamples])

  malformed_inputs = (
      (labels[None], logits),
      (labels, logits[None]),
      (labels, logits.T),
  )
  for bad_labels, bad_logits in malformed_inputs:
    with self.assertRaises(ValueError):
      calibration_lib.find_scaling_temperature(bad_labels, bad_logits)
def run(prediction_path):
  """Fits a scaling temperature to saved predictions and writes it as JSON.

  The stats loaded from `prediction_path` must provide 'probs' and 'labels'
  arrays. At most NUM_EXAMPLES leading examples are used. The result is
  written to 'temperature_hparam.json' in the directory that contains
  `prediction_path`.

  Args:
    prediction_path: Path handed to `array_utils.load_stats_from_tfrecords`.
  """
  stats = array_utils.load_stats_from_tfrecords(prediction_path)
  probs = stats['probs'].astype(np.float32)
  labels = stats['labels'].astype(np.int32)

  # Drop the trailing axis when labels are not already one-dimensional.
  if labels.ndim > 1:
    labels = np.squeeze(labels, -1)

  # Cap the number of examples used for the fit.
  if probs.shape[0] > NUM_EXAMPLES:
    probs = probs[:NUM_EXAMPLES, :]
    labels = labels[:NUM_EXAMPLES]

  # Recover logits from the (softened) probabilities before fitting.
  softened = metrics_lib.soften_probabilities(probs=probs)
  logits = uq_utils.np_inverse_softmax(softened)
  temp = calibration_lib.find_scaling_temperature(labels, logits)

  output_path = os.path.join(
      os.path.dirname(prediction_path), 'temperature_hparam.json')
  payload = json.dumps({'temperature': temp})
  with gfile.GFile(output_path, 'w') as fh:
    fh.write(payload)
def main(_):
  """Evaluates saved classifier checkpoints on in-dist., skewed and OOD data.

  For every checkpoint used, the model is rebuilt from its saved
  'params.json', its weights are loaded, and predictions are made on the
  in-distribution test set, the skewed in-distribution test set and the OOD
  set. Predictions are averaged over `FLAGS.n_pred_sample` forward passes and
  then over the checkpoints. When `FLAGS.is_tempscale` is set, test
  predictions are temperature-scaled with a temperature fitted on the
  in-distribution dev set. Accuracy and AUC values are printed, and the
  averaged predictions are pickled into the method's model directory.

  If the checkpoint listing is empty, or no listed entry contains a 'model'
  directory, a fatal error is logged and the function returns without writing
  a prediction file.

  Args:
    _: Unused positional argument (presumably argv from the app runner).
  """
  if FLAGS.is_tempscale:
    tf.enable_v2_behavior()

  params = {
      'num_epochs': FLAGS.num_epochs,
      'fix_len': FLAGS.fix_len,
      'batch_size': FLAGS.batch_size,
      'n_class': FLAGS.n_class,
      'emb_size': FLAGS.emb_size,
      'vocab_size': FLAGS.vocab_size,
      'hidden_lstm_size': FLAGS.hidden_lstm_size,
      'dropout_rate': FLAGS.dropout_rate,
      'dropout_rate_lstm': FLAGS.dropout_rate_lstm,
      'learning_rate': FLAGS.learning_rate,
      'reg_weight': FLAGS.reg_weight,
      'tr_out_dir': FLAGS.tr_out_dir,
      'data_pkl_file': FLAGS.data_pkl_file,
      'master': FLAGS.master,
      'clip_norm': FLAGS.clip_norm,
      'random_seed': FLAGS.random_seed,
      'variational': FLAGS.variational,
      'n_class_in': None,
      'n_train': None,
  }

  # load in-dist. and skewed in-dist. datasets
  data = classifier.load_np_dataset(params['data_pkl_file'])

  # load OOD dataset
  n_ood = 5600
  test_lm1b_x_pad, _ = load_ood_dataset(n_ood, params['fix_len'], data.vocab,
                                        params['vocab_size'])

  # list of ckpt dir
  model_dir = os.path.join(FLAGS.model_dir, FLAGS.method)
  ckpt_dirs = tf.io.gfile.listdir(model_dir)

  # Check for an empty listing before anything indexes into it. The explicit
  # return covers logging backends whose fatal() does not terminate the
  # process.
  if not ckpt_dirs:
    logging.fatal('no model ckpt')
    return

  # how many replicates for ensemble
  if FLAGS.is_ensemble:
    assert len(ckpt_dirs) > 1
    n_ensemble = len(ckpt_dirs)
  else:
    # Without ensembling only the first listed checkpoint is evaluated.
    n_ensemble = 1

  split_names = ('in', 'skew', 'ood')
  pred = {}  # dict for final prediction score
  # dict for saving pred from different models
  pred_accum = {name: [] for name in split_names}

  for ckpt_name in ckpt_dirs[:n_ensemble]:
    ckpt_dir = os.path.join(model_dir, ckpt_name, 'model')
    if not tf.io.gfile.isdir(ckpt_dir):
      continue
    print('ckpt_dir={}'.format(ckpt_dir))

    # load params
    with tf.gfile.GFile(os.path.join(ckpt_dir, 'params.json'), mode='rb') as f:
      params_json = yaml.safe_load(f)
    params.update(params_json)
    params['master'] = ''
    print('params after load={}'.format(params))

    tf.reset_default_graph()
    # create model
    model = classifier.rnn_model(
        params,
        training_dr_lstm=params['dropout_rate_lstm'] != 0.0,
        training_dr_ll=params['dropout_rate'] != 0.0)

    # load model
    model.load_weights(ckpt_dir + '/model.ckpt')

    # predict
    if FLAGS.method in ('ll-svi', 'dropout', 'll-dropout'):
      # need to run multiple times and get mean prediction
      assert FLAGS.n_pred_sample > 1
    else:
      FLAGS.n_pred_sample = 1

    pred_k = {name: [] for name in split_names}
    for _ in range(FLAGS.n_pred_sample):
      pred_tr_in = model.predict(data.in_sample_examples)
      acc_tr_in = np.mean(
          data.in_sample_labels == np.argmax(pred_tr_in, axis=1))
      pred_test_in = model.predict(data.test_in_sample_examples)
      acc_test_in = np.mean(
          data.test_in_sample_labels == np.argmax(pred_test_in, axis=1))
      print('in-dist. acc_tr={}, acc_test={}'.format(acc_tr_in, acc_test_in))
      pred_test_skew = model.predict(data.test_oos_examples)
      pred_test_ood = model.predict(test_lm1b_x_pad)

      if FLAGS.is_tempscale:
        # temperature scaling
        # logits for temp scaling
        last_layer_model = models.Model(
            inputs=model.input, outputs=model.get_layer('last_layer').output)
        logits = last_layer_model.predict(data.dev_in_sample_examples)
        opt_temp = calibration_lib.find_scaling_temperature(
            data.dev_in_sample_labels, logits, temp_range=(1e-5, 1e5))
        pred_test_in = calibration_lib.apply_temperature_scaling(
            opt_temp, pred_test_in)
        pred_test_skew = calibration_lib.apply_temperature_scaling(
            opt_temp, pred_test_skew)
        pred_test_ood = calibration_lib.apply_temperature_scaling(
            opt_temp, pred_test_ood)

      # save in a list
      pred_k['in'].append(pred_test_in)
      pred_k['skew'].append(pred_test_skew)
      pred_k['ood'].append(pred_test_ood)

    # Average the forward passes of this checkpoint, split by split.
    for name in split_names:
      pred_accum[name].append(np.mean(np.stack(pred_k[name]), axis=0))

  # Every listed entry may have been skipped above; stacking an empty list
  # would otherwise fail with an unrelated-looking ValueError.
  if not pred_accum['in']:
    logging.fatal('no usable model ckpt in %s', model_dir)
    return

  # if ensemble, then take the mean
  for name in split_names:
    pred[name] = np.mean(np.stack(pred_accum[name]), axis=0)

  # prediction accuracy for in-dist.
  pred['in_true_labels'] = data.test_in_sample_labels
  acc = np.mean(data.test_in_sample_labels == np.argmax(pred['in'], axis=1))
  print('== (optionally ensemble) acc={} =='.format(acc))

  # AUC of in-dist. vs. each shifted split, scored by max(Py|x).
  for shifted in ('skew', 'ood'):
    print('== eval in and {} using max(Py|x) =='.format(shifted))
    neg = list(np.max(pred['in'], axis=1))
    pos = list(np.max(pred[shifted], axis=1))
    print('auc={}'.format(compute_auc(neg, pos, pos_label=0)))

  # save the predictions
  pred_file_name = 'pred_nensemb{}_npred{}_tempscale{}.pkl'.format(
      len(pred_accum['in']), FLAGS.n_pred_sample, FLAGS.is_tempscale)
  with tf.gfile.Open(os.path.join(model_dir, pred_file_name), 'wb') as f:
    pickle.dump(pred, f, protocol=2)