示例#1
0
def get_rejection_thresholds(
    it, model: NeuralModelBase, dataset: Dataset, precision_thresholds: Iterable[float]
):
    """Pick a reject-probability cut-off for each requested precision level.

    Every batch from ``it`` is run through
    ``model.predict_probs_with_reject``; positions selected by the model's
    ``mask_valid`` padding mask are histogrammed by reject probability into
    ``num_bins`` equal-width bins over [0, 1], once for all positions and once
    for the positions whose best prediction equals the target.

    The bins are then scanned in increasing order of reject probability while
    keeping cumulative correct/total counts. For each precision level, the
    cut-off is (re)assigned whenever the cumulative precision of the bins
    strictly below the current one meets the level and either no cut-off has
    been chosen yet or the cumulative sample count grew by more than 1%
    since the last assignment.

    Args:
        it: Iterable of batches accepted by ``model``.
        model: Model providing ``predict_probs_with_reject`` and
            ``padding_mask``.
        dataset: Supplies ``reject_token_id``.
        precision_thresholds: Required precision levels. May be any iterable,
            including a one-shot iterator.

    Returns:
        A list of ``SimpleNamespace(h, size)``, one per precision level and in
        the same order. ``h`` is the chosen cut-off (``None`` if the level was
        never reached) and ``size`` is the cumulative number of samples at the
        moment ``h`` was last assigned (``0`` if never).
    """
    # The levels are walked once to allocate the result slots and then again
    # for every bin; a generator would be exhausted after the first pass, so
    # materialize it up front.
    precision_thresholds = list(precision_thresholds)
    num_bins = 1000

    num_correct = torch.zeros(num_bins)
    num_total = torch.zeros(num_bins)
    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best_predictions, reject_probs = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        mask = model.padding_mask(batch, mask_field="mask_valid")
        targets = batch.Y

        best_predictions = best_predictions.masked_select(mask)
        reject_probs = reject_probs.masked_select(mask).cpu()
        targets = targets.masked_select(mask)

        is_corrects = (targets == best_predictions).cpu()

        num_total.add_(torch.histc(reject_probs, bins=num_bins, min=0, max=1))
        num_correct.add_(
            torch.histc(
                reject_probs.masked_select(is_corrects), bins=num_bins, min=0, max=1
            )
        )

    def precision(stat):
        # An empty prefix has no defined precision; report 0 so it can only
        # satisfy non-positive precision levels.
        if stat.total == 0:
            return 0
        return stat.correct * 1.0 / stat.total

    thresholds = [SimpleNamespace(h=None, size=0) for _ in precision_thresholds]
    rolling_stat = SimpleNamespace(correct=0, total=0)
    # tolist() yields Python floats, so the cumulative sums are kept in double
    # precision instead of float32 (which stops being exact above 2**24).
    for i, (correct, total) in enumerate(
        zip(num_correct.tolist(), num_total.tolist())
    ):
        # At this point rolling_stat covers bins [0, i) only.
        for t, precision_threshold in zip(thresholds, precision_thresholds):
            if precision_threshold <= precision(rolling_stat):
                # update threshold if it's not set or the number of samples increased
                if t.h is None or t.size * 1.01 < rolling_stat.total:
                    t.h = i / float(num_bins)
                    t.size = int(rolling_stat.total)

        rolling_stat.correct += correct
        rolling_stat.total += total

    # NOTE(review): the cumulative state that includes the last bin
    # (h == 1.0) is never evaluated above, so the largest cut-off that can be
    # returned is (num_bins - 1) / num_bins. Confirm whether that is intended.

    Logger.debug(
        "Thresholds: {}, sizes: {}".format(
            [t.h for t in thresholds], [t.size for t in thresholds]
        )
    )
    return thresholds
示例#2
0
def print_rejection_thresholds(it, model: NeuralModelBase, dataset: Dataset):
    """Log prediction accuracy at a fixed grid of reject-probability cut-offs.

    For every cut-off ``h`` in ``np.arange(0.1, 1.1, 0.1)`` the function
    counts, over all batches in ``it``, the valid positions (per the model's
    ``mask_valid`` padding mask) whose reject probability is ``<= h`` and how
    many of those have a best prediction equal to the target. One debug line
    is emitted per cut-off, followed by a line with the overall counts. The
    percentage in each line is whatever ``acc(correct, total)`` returns.

    Args:
        it: Iterable of batches accepted by ``model``.
        model: Model providing ``predict_probs_with_reject`` and
            ``padding_mask``.
        dataset: Supplies ``reject_token_id``.
    """
    cutoffs = np.arange(0.1, 1.1, 0.1)
    # One running tally per cut-off, plus one for all valid positions.
    per_cutoff = {c: SimpleNamespace(correct=0, total=0) for c in cutoffs}
    overall = SimpleNamespace(correct=0, total=0)

    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, predicted, p_reject = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        valid = model.padding_mask(batch, mask_field="mask_valid")
        gold = batch.Y

        # Flatten everything down to the valid positions only.
        predicted = predicted.masked_select(valid)
        p_reject = p_reject.masked_select(valid)
        gold = gold.masked_select(valid)

        hits = gold == predicted
        overall.correct += hits.sum().item()
        overall.total += gold.numel()

        for c in cutoffs:
            kept = p_reject <= c
            tally = per_cutoff[c]
            tally.total += kept.sum().item()
            tally.correct += hits.masked_select(kept).sum().item()

    for c in cutoffs:
        tally = per_cutoff[c]
        line = "Threshold {:5.2f}: {:6d}/{:6d} ({:.2f}%)".format(
            c, tally.correct, tally.total, acc(tally.correct, tally.total)
        )
        Logger.debug(line)

    summary = "{:6d}/{:6d} ({:.2f}%)".format(
        overall.correct, overall.total, acc(overall.correct, overall.total)
    )
    Logger.debug(summary)