def test_apply_scaling_temperature(self):
        """Temperature-scaling softmax probs must equal softmax(logits / T).

        Builds random logits of shape [7, 11, 5], converts them to
        probabilities, and checks that `apply_temperature_scaling` on those
        probabilities reproduces the softmax of the temperature-divided logits,
        with the original shape preserved.
        """
        nclasses = 5
        full_shape = [7, 11] + [nclasses]
        temperature = 3
        logits = tf.random.uniform(full_shape, 0, 10)

        # The function under test only sees probabilities, never the logits.
        scaled = calibration_lib.apply_temperature_scaling(
            temperature, tf.math.softmax(logits, axis=-1))
        reference = tf.math.softmax(logits / temperature, axis=-1)

        self.assertAllEqual(scaled.shape, logits.shape)
        self.assertAllClose(reference, scaled)
Example #2
0
def _list_model_ckpt_dirs(model_dir):
  """Returns sorted `<model_dir>/<run>/model` paths that exist as directories.

  `main` writes its prediction pickles into `model_dir` too, so entries that
  have no `model` subdirectory are filtered out here instead of being counted
  as ensemble members. Sorting makes the checkpoint choice independent of the
  filesystem's listing order.
  """
  ckpt_dirs = []
  for name in sorted(tf.io.gfile.listdir(model_dir)):
    ckpt_dir = os.path.join(model_dir, name, 'model')
    if tf.io.gfile.isdir(ckpt_dir):
      ckpt_dirs.append(ckpt_dir)
  return ckpt_dirs


def _predict_with_ckpt(ckpt_dir, base_params, data, ood_examples,
                       n_pred_sample):
  """Loads one checkpoint and returns its predictions averaged over samples.

  Args:
    ckpt_dir: directory containing `params.json` and the `model.ckpt` weights.
    base_params: flag-derived params dict. The checkpoint's `params.json`
      overrides it on a private copy; `base_params` itself is not mutated.
    data: dataset object returned by `classifier.load_np_dataset`.
    ood_examples: padded out-of-distribution inputs.
    n_pred_sample: number of forward passes to average.

  Returns:
    dict with keys 'in', 'skew', 'ood'. Each maps to the element-wise mean,
    over the `n_pred_sample` passes, of the (optionally temperature-scaled)
    predictions on the in-dist. test set, skewed test set and OOD set.
  """
  print('ckpt_dir={}'.format(ckpt_dir))

  # Fresh copy per checkpoint, so keys from a previously loaded params.json
  # cannot leak into this model's configuration.
  params = dict(base_params)
  with tf.io.gfile.GFile(os.path.join(ckpt_dir, 'params.json'),
                         mode='rb') as f:
    params.update(yaml.safe_load(f))
  params['master'] = ''
  print('params after load={}'.format(params))

  tf.reset_default_graph()
  # Dropout stays enabled at prediction time whenever its rate is non-zero,
  # which is why several stochastic passes are averaged below.
  model = classifier.rnn_model(
      params,
      training_dr_lstm=params['dropout_rate_lstm'] != 0.0,
      training_dr_ll=params['dropout_rate'] != 0.0)
  model.load_weights(os.path.join(ckpt_dir, 'model.ckpt'))

  last_layer_model = None
  if FLAGS.is_tempscale:
    # Sub-model exposing the 'last_layer' output, used to fit the temperature
    # on the dev split. Presumably these are pre-softmax logits; confirm
    # against classifier.rnn_model.
    last_layer_model = models.Model(
        inputs=model.input, outputs=model.get_layer('last_layer').output)

  preds = {'in': [], 'skew': [], 'ood': []}
  for _ in range(n_pred_sample):
    pred_tr_in = model.predict(data.in_sample_examples)
    acc_tr_in = np.mean(
        data.in_sample_labels == np.argmax(pred_tr_in, axis=1))

    pred_test_in = model.predict(data.test_in_sample_examples)
    acc_test_in = np.mean(
        data.test_in_sample_labels == np.argmax(pred_test_in, axis=1))
    print('in-dist. acc_tr={}, acc_test={}'.format(acc_tr_in, acc_test_in))

    pred_test_skew = model.predict(data.test_oos_examples)
    pred_test_ood = model.predict(ood_examples)

    if last_layer_model is not None:
      # Re-fit the temperature on each pass, because a stochastic model
      # produces different dev logits on every forward pass.
      logits = last_layer_model.predict(data.dev_in_sample_examples)
      opt_temp = calibration_lib.find_scaling_temperature(
          data.dev_in_sample_labels, logits, temp_range=(1e-5, 1e5))
      pred_test_in = calibration_lib.apply_temperature_scaling(
          opt_temp, pred_test_in)
      pred_test_skew = calibration_lib.apply_temperature_scaling(
          opt_temp, pred_test_skew)
      pred_test_ood = calibration_lib.apply_temperature_scaling(
          opt_temp, pred_test_ood)

    preds['in'].append(pred_test_in)
    preds['skew'].append(pred_test_skew)
    preds['ood'].append(pred_test_ood)

  return {key: np.mean(np.stack(vals), axis=0) for key, vals in preds.items()}


def main(_):
  """Evaluates saved classifiers on in-dist., skewed and OOD data.

  Loads one checkpoint (or all of them when `--is_ensemble`) from
  `<model_dir>/<method>/<run>/model`. It averages predictions over
  `n_pred_sample` passes and then over ensemble members, and prints the
  in-dist. accuracy and max-probability AUCs (in vs. skew, in vs. OOD). The
  averaged predictions are pickled into `<model_dir>/<method>`.

  Raises:
    ValueError: if no usable checkpoint is found, if an ensemble has fewer
      than two members, or if a sampling method has `n_pred_sample <= 1`.
  """
  if FLAGS.is_tempscale:
    tf.enable_v2_behavior()

  params = {
      'num_epochs': FLAGS.num_epochs,
      'fix_len': FLAGS.fix_len,
      'batch_size': FLAGS.batch_size,
      'n_class': FLAGS.n_class,
      'emb_size': FLAGS.emb_size,
      'vocab_size': FLAGS.vocab_size,
      'hidden_lstm_size': FLAGS.hidden_lstm_size,
      'dropout_rate': FLAGS.dropout_rate,
      'dropout_rate_lstm': FLAGS.dropout_rate_lstm,
      'learning_rate': FLAGS.learning_rate,
      'reg_weight': FLAGS.reg_weight,
      'tr_out_dir': FLAGS.tr_out_dir,
      'data_pkl_file': FLAGS.data_pkl_file,
      'master': FLAGS.master,
      'clip_norm': FLAGS.clip_norm,
      'random_seed': FLAGS.random_seed,
      'variational': FLAGS.variational,
      'n_class_in': None,
      'n_train': None,
  }

  # load in-dist. and skewed in-dist. datasets
  data = classifier.load_np_dataset(params['data_pkl_file'])

  # load OOD dataset
  n_ood = 5600
  test_lm1b_x_pad, _ = load_ood_dataset(n_ood, params['fix_len'], data.vocab,
                                        params['vocab_size'])

  # Only real checkpoint directories count. Prediction pickles written below
  # live in the same directory.
  model_dir = os.path.join(FLAGS.model_dir, FLAGS.method)
  ckpt_dirs = _list_model_ckpt_dirs(model_dir)
  if not ckpt_dirs:
    raise ValueError('no model ckpt under {}'.format(model_dir))
  if FLAGS.is_ensemble:
    if len(ckpt_dirs) < 2:
      raise ValueError('ensemble needs >1 model ckpt, found {} under {}'.format(
          len(ckpt_dirs), model_dir))
  else:
    ckpt_dirs = ckpt_dirs[:1]

  # Stochastic methods need several passes to get a mean prediction. Every
  # other method is deterministic, so a single pass is enough.
  if FLAGS.method in ['ll-svi', 'dropout', 'll-dropout']:
    if FLAGS.n_pred_sample <= 1:
      raise ValueError('method {} requires n_pred_sample > 1, got {}'.format(
          FLAGS.method, FLAGS.n_pred_sample))
    n_pred_sample = FLAGS.n_pred_sample
  else:
    n_pred_sample = 1

  # Per-checkpoint (sample-averaged) predictions.
  pred_accum = {'in': [], 'skew': [], 'ood': []}
  for ckpt_dir in ckpt_dirs:
    ckpt_pred = _predict_with_ckpt(ckpt_dir, params, data, test_lm1b_x_pad,
                                   n_pred_sample)
    for key, vals in pred_accum.items():
      vals.append(ckpt_pred[key])

  # If ensemble, take the mean over members (a no-op for a single model).
  pred = {
      key: np.mean(np.stack(vals), axis=0) for key, vals in pred_accum.items()
  }

  # prediction accuracy for in-dist.
  pred['in_true_labels'] = data.test_in_sample_labels
  acc = np.mean(data.test_in_sample_labels == np.argmax(pred['in'], axis=1))
  print('== (optionally ensemble) acc={} =='.format(acc))

  # AUC of max class probability: in-dist. (neg) vs. skew / OOD (pos).
  for key in ('skew', 'ood'):
    print('== eval in and {} using max(Py|x) =='.format(key))
    neg = list(np.max(pred['in'], axis=1))
    pos = list(np.max(pred[key], axis=1))
    print('auc={}'.format(compute_auc(neg, pos, pos_label=0)))

  # save the predictions
  pred_file_name = 'pred_nensemb{}_npred{}_tempscale{}.pkl'.format(
      len(ckpt_dirs), n_pred_sample, FLAGS.is_tempscale)
  with tf.io.gfile.GFile(os.path.join(model_dir, pred_file_name), 'wb') as f:
    pickle.dump(pred, f, protocol=2)