def get_rejection_thresholds(
    it,
    model: NeuralModelBase,
    dataset: Dataset,
    precision_thresholds: Iterable[float],
    num_bins: int = 1000,
):
    """Pick, for each target precision, a reject-probability cut-off that meets it.

    Every batch in ``it`` is passed through ``model.predict_probs_with_reject``.
    For the positions selected by ``model.padding_mask(batch, mask_field="mask_valid")``
    (presumably the valid, non-padding tokens - confirm against the model), the
    reject probabilities are bucketed into ``num_bins`` equal-width bins over
    [0, 1] with ``torch.histc``. This is done twice: once for all selected
    positions, and once for those whose best prediction equals ``batch.Y``.

    Candidate cut-offs ``h = i / num_bins`` for ``i = 0 .. num_bins`` are then
    scanned in increasing order. The samples covered by ``h`` are those in bins
    ``0 .. i-1``; for ``h = 1.0`` that is all bins. Their precision is
    correct / total, or 0 when nothing is covered.

    Whenever that precision is at least a target, the target's cut-off is
    recorded if it is still unset. It is also moved up if the covered sample
    count has grown by more than 1% since the last recorded cut-off.

    Args:
        it: iterable of batches. Each batch must expose ``Y`` and be accepted by
            ``model.predict_probs_with_reject`` and ``model.padding_mask``.
        model: model producing best predictions and reject probabilities.
        dataset: supplies ``reject_token_id`` for the model call.
        precision_thresholds: target precisions. Any iterable is accepted,
            including a one-shot generator, because it is materialized once.
        num_bins: histogram resolution; cut-offs are multiples of ``1 / num_bins``.

    Returns:
        A list with one ``SimpleNamespace(h=..., size=...)`` per target, in the
        same order as ``precision_thresholds``. ``h`` is the chosen cut-off, or
        ``None`` if no candidate reached the target. ``size`` is the number of
        covered samples when ``h`` was recorded.
    """
    # The targets are traversed once per candidate below; a generator would be
    # exhausted after building ``thresholds`` and every target silently skipped.
    precision_thresholds = list(precision_thresholds)

    # float64 accumulators: float32 counts stop being exact above 2**24 samples.
    num_correct = torch.zeros(num_bins, dtype=torch.float64)
    num_total = torch.zeros(num_bins, dtype=torch.float64)
    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best_predictions, reject_probs = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        mask = model.padding_mask(batch, mask_field="mask_valid")

        best_predictions = best_predictions.masked_select(mask)
        reject_probs = reject_probs.masked_select(mask).cpu()
        targets = batch.Y.masked_select(mask)

        is_correct = (targets == best_predictions).cpu()
        num_total.add_(torch.histc(reject_probs, bins=num_bins, min=0, max=1))
        num_correct.add_(
            torch.histc(
                reject_probs.masked_select(is_correct), bins=num_bins, min=0, max=1
            )
        )

    thresholds = [SimpleNamespace(h=None, size=0) for _ in precision_thresholds]
    correct_per_bin = num_correct.tolist()
    total_per_bin = num_total.tolist()
    rolling_correct = 0.0
    rolling_total = 0.0
    # num_bins + 1 candidates: the last one evaluates h = 1.0 (every bin), which a
    # check-then-add loop over only num_bins iterations never reaches.
    for i in range(num_bins + 1):
        precision = rolling_correct / rolling_total if rolling_total else 0.0
        for t, precision_threshold in zip(thresholds, precision_thresholds):
            if precision_threshold > precision:
                continue
            # update threshold if it's not set or the number of samples increased
            if t.h is None or t.size * 1.01 < rolling_total:
                t.h = i / float(num_bins)
                t.size = int(rolling_total)
        if i < num_bins:
            rolling_correct += correct_per_bin[i]
            rolling_total += total_per_bin[i]

    Logger.debug(
        "Thresholds: {}, sizes: {}".format(
            [t.h for t in thresholds], [t.size for t in thresholds]
        )
    )
    return thresholds
def print_rejection_thresholds(
    it, model: NeuralModelBase, dataset: Dataset, thresholds=None
):
    """Log accuracy of the predictions kept under several reject-probability cut-offs.

    Every batch in ``it`` is passed through ``model.predict_probs_with_reject``,
    and only the positions selected by
    ``model.padding_mask(batch, mask_field="mask_valid")`` are counted. A
    prediction is correct when it equals ``batch.Y``. At cut-off ``h``, a
    position is kept when its reject probability is ``<= h``.

    After all batches, one ``Logger.debug`` line is emitted per cut-off
    (correct/kept and the ``acc`` percentage). A final line reports the overall
    correct/total over all counted positions.

    Args:
        it: iterable of batches. Each batch must expose ``Y`` and be accepted by
            ``model.predict_probs_with_reject`` and ``model.padding_mask``.
        model: model producing best predictions and reject probabilities.
        dataset: supplies ``reject_token_id`` for the model call.
        thresholds: optional iterable of cut-offs. Defaults to
            0.1, 0.2, ..., 1.0.
    """
    # np.arange(0.1, 1.1, 0.1) yields values such as 0.30000000000000004 and its
    # length depends on float rounding; build the default grid from integers.
    if thresholds is None:
        thresholds = [step / 10 for step in range(1, 11)]
    else:
        thresholds = list(thresholds)
    # One counter per cut-off, matched positionally, so duplicate cut-off values
    # cannot share (and double-count into) a single float-keyed entry.
    stats = [SimpleNamespace(correct=0, total=0) for _ in thresholds]

    num_correct = 0
    num_total = 0
    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best_predictions, reject_probs = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        mask = model.padding_mask(batch, mask_field="mask_valid")

        best_predictions = best_predictions.masked_select(mask)
        reject_probs = reject_probs.masked_select(mask)
        targets = batch.Y.masked_select(mask)

        is_correct = targets == best_predictions
        num_correct += torch.sum(is_correct).item()
        num_total += targets.numel()

        for h, stat in zip(thresholds, stats):
            h_mask = reject_probs <= h
            stat.total += torch.sum(h_mask).item()
            stat.correct += torch.sum(is_correct.masked_select(h_mask)).item()

    for h, stat in zip(thresholds, stats):
        Logger.debug(
            "Threshold {:5.2f}: {:6d}/{:6d} ({:.2f}%)".format(
                h,
                stat.correct,
                stat.total,
                acc(stat.correct, stat.total),
            )
        )
    Logger.debug(
        "{:6d}/{:6d} ({:.2f}%)".format(
            num_correct, num_total, acc(num_correct, num_total)
        )
    )