Example #1
0
def calculate_f1_metric(
    metric: F1MetricMethodName,
    model,
    test_loader,
    sp: spm.SentencePieceProcessor,
    use_cuda=True,
    beam_search_k=8,
    max_decode_len=10,  # see empirical evaluation of CDF of subword token lengths
    logger_fn=None,
):
    with Timer() as t:
        sample_generations = []
        n_examples = 0
        precision, recall, f1 = 0.0, 0.0, 0.0
        pbar = tqdm.tqdm(test_loader, desc="test")
        for X, Y, X_lengths, Y_lengths in pbar:
            if use_cuda:
                X, Y = X.cuda(), Y.cuda()  # B, L
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            with Timer() as decode_timer:
                pred, scores = beam_search_decode(
                    model,
                    X,
                    X_lengths,
                    sp,
                    k=beam_search_k,
                    max_decode_len=max_decode_len)
            logger.info(
                f"Took {decode_timer.interval:.2f}s to decode {X.size(0)} identifiers")
            for i in range(X.size(0)):
                gt_identifier = ids_to_strs(Y[i], sp)
                top_beam = pred[i][0]
                pred_dict = {"gt": gt_identifier}
                for j, beam_result in enumerate(pred[i]):
                    pred_dict[f"pred_{j}"] = beam_result
                sample_generations.append(pred_dict)
                precision_item, recall_item, f1_item = metric(
                    top_beam, gt_identifier)
                precision += precision_item
                recall += recall_item
                f1 += f1_item
                n_examples += 1
                if logger_fn is not None:
                    logger_fn({
                        "precision_item": precision_item,
                        "recall_item": recall_item,
                        "f1_item": f1_item
                    })
                    logger_fn({
                        "precision_avg": precision / n_examples,
                        "recall_avg": recall / n_examples,
                        "f1_avg": f1 / n_examples
                    })
    logger.debug(
        f"Test set evaluation (F1) took {t.interval:.3}s over {n_examples} samples"
    )
    return precision / n_examples, recall / n_examples, f1 / n_examples, sample_generations
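The metric callable typed F1MetricMethodName above is not defined in this example. A minimal sketch of one plausible token-level precision/recall/F1 over identifier subtokens, matching only the call signature used above; the underscore-based subtoken split and the name token_f1 are assumptions, not the original implementation.
def token_f1(predicted: str, ground_truth: str):
    # Hypothetical metric: set overlap of identifier subtokens (assumes snake_case names).
    pred_tokens = set(predicted.lower().split("_"))
    gt_tokens = set(ground_truth.lower().split("_"))
    overlap = len(pred_tokens & gt_tokens)
    precision = overlap / len(pred_tokens) if pred_tokens else 0.0
    recall = overlap / len(gt_tokens) if gt_tokens else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1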
Example #2
0
def calculate_nll(
    model,
    test_loader,
    sp: spm.SentencePieceProcessor,
    use_cuda=True,
    logger_fn=None
):
    with Timer() as t:
        pad_id = sp.PieceToId("[PAD]")
        n_examples = 0
        test_nll = 0.
        pbar = tqdm.tqdm(test_loader, desc="test")
        for X, Y, X_lengths, Y_lengths in pbar:
            B, L = X.shape
            if use_cuda:
                X, Y = X.cuda(), Y.cuda()  # B, L
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            pred_y = model(X, Y[:, :-1].to(X.device), X_lengths, Y_lengths)
            B, T, V = pred_y.shape  # avoid clobbering the input tensor X with a dimension size
            loss = F.cross_entropy(pred_y.reshape(B * T, V), Y[:, 1:].reshape(B * T), ignore_index=pad_id, reduction='sum')
            
            n_examples += B
            test_nll += loss.item()
            if logger_fn is not None:
                logger_fn({'test_nll': loss.item() / B, 'test_nll_avg': test_nll / n_examples})
        return test_nll / n_examples
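The per-example NLL above sums token losses (ignoring pad positions) and divides by the number of sequences rather than the number of tokens. A tiny standalone check of that accounting on dummy tensors; shapes and the pad id are illustrative assumptions.
import torch
import torch.nn.functional as F

pad_id = 0
B, T, V = 2, 5, 11
logits = torch.randn(B, T, V)
targets = torch.randint(1, V, (B, T))  # no pad tokens here, for simplicity
sum_loss = F.cross_entropy(
    logits.reshape(B * T, V), targets.reshape(B * T),
    ignore_index=pad_id, reduction="sum")
print("NLL per sequence:", sum_loss.item() / B)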
Example #3
0
    def __setstate__(self, state):
        with Timer() as t:
            (self.spm_unigram_path, self.subword_regularization_alpha,
             self.max_length) = state
            self.sp_model = load_sp_model(self.spm_unigram_path)
            self.bos_id = self.sp_model.PieceToId("<s>")
            self.eos_id = self.sp_model.PieceToId("</s>")
            self.pad_id = self.sp_model.PieceToId("[PAD]")
        logger.info("Hydrating vocabulary took {:.3f}s".format(t.interval))
Example #4
0
def _evaluate(
    model,
    loader,
    sp: spm.SentencePieceProcessor,
    use_cuda=True,
    num_to_print=8,
    beam_search_k=5,
    max_decode_len=20,
    loss_type="nll_token",
):
    model.eval()
    pad_id = sp.PieceToId("[PAD]")

    with torch.no_grad():
        # with Timer() as t:
        #     # Decode a single batch by beam search for visualization
        #     X, Y, X_lengths, _ = next(iter(loader))
        #     X, Y = X[:num_to_print], Y[:num_to_print]
        #     if use_cuda:
        #         X = X.cuda()
        #         X_lengths = X.cuda()
        #     pred, scores = beam_search_decode(model, X, X_lengths, sp, k=beam_search_k, max_decode_len=max_decode_len)
        #     for i in range(X.size(0)):
        #         logger.info(f"Eval X:   \t\t\t{ids_to_strs(X[i], sp)}")
        #         logger.info(f"Eval GT Y:\t\t\t{ids_to_strs(Y[i], sp)}")
        #         for b in range(scores.size(1)):
        #             logger.info(f"Eval beam (score={scores[i, b]:.3f}):\t{pred[i][b]}")
        # logger.debug(f"Decode time for {num_to_print} samples took {t.interval:.3f}")

        with Timer() as t:
            # Compute average loss
            total_loss = 0
            num_examples = 0
            pbar = tqdm.tqdm(loader, desc="evaluate")
            for X, Y, X_lengths, Y_lengths in pbar:
                if use_cuda:
                    X, Y = X.cuda(), Y.cuda()
                    X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
                # NOTE: X and Y are [B, max_seq_len] tensors (batch first)
                logits = model(X, Y[:, :-1], X_lengths, Y_lengths)
                if loss_type == "nll_sequence":
                    loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, reduction="sum")
                    loss = loss / X.size(0)  # Average over num sequences, not target sequence lengths
                    # Thus, minimize bits per sequence.
                elif loss_type == "nll_token":
                    loss = F.cross_entropy(
                        logits.transpose(1, 2),
                        Y[:, 1:],
                        ignore_index=pad_id,
                    )
                else:
                    raise ValueError(f"Unsupported loss_type: {loss_type}")

                # TODO: Compute Precision/Recall/F1 and BLEU

                total_loss += loss.item() * X.size(0)
                num_examples += X.size(0)
                avg_loss = total_loss / num_examples
                pbar.set_description(f"evaluate average loss {avg_loss:.4f}")
        logger.debug(f"Loss calculation took {t.interval:.3f}s")
        return avg_loss
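Both loss branches above pass class-dim-second logits to F.cross_entropy, hence the transpose(1, 2). A small standalone check of the two reductions on dummy tensors; the shapes are illustrative assumptions.
import torch
import torch.nn.functional as F

B, T, V = 2, 7, 13
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))
# cross_entropy expects [B, V, T] logits for [B, T] targets.
nll_token = F.cross_entropy(logits.transpose(1, 2), targets)                           # mean over tokens
nll_sequence = F.cross_entropy(logits.transpose(1, 2), targets, reduction="sum") / B   # sum / num sequences
print(nll_token.item(), nll_sequence.item())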
Example #5
0
    def __init__(
        self,
        path,
        sp,
        min_alternatives=1,
        limit_size=-1,
        max_length=1024,
        subword_regularization_alpha=0.1,
        program_mode="identity",
        preloaded_examples=None,
    ):
        """Create a JSONLinesDataset given a path and field mapping dictionary.
        Arguments:
            path (str): Path to the data file. Must be in .pickle format.
        """
        super().__init__()
        full_path = pathlib.Path(path).resolve()
        if preloaded_examples is not None:
            logger.debug("Using preloaded examples passed via argument")
            self.examples = preloaded_examples
        else:
            logger.debug(f"Loading {full_path}")
            with Timer() as t:
                if str(path).endswith(".gz"):
                    with gzip.open(str(full_path), "rb") as f:
                        self.examples = pickle.load(f)
                else:
                    with full_path.open("rb") as f:
                        self.examples = pickle.load(f)
            logger.debug(
                f"Loaded {len(self.examples)} examples in {t.interval:.3f}s")
        if limit_size > 0:
            self.examples = self.examples[:limit_size]
            logger.debug(f"Limited size: took first {limit_size} examples")
        self.examples = list(map(list, self.examples))
        logger.debug("Converted examples to lists of alternatives")
        if min_alternatives:
            self.examples = list(
                filter(lambda ex: len(ex) >= min_alternatives, self.examples))
            logger.debug(
                f"Filtered dataset to {len(self.examples)} examples with at least {min_alternatives} alternatives"
            )

        self.program_mode = program_mode
        self.max_length = max_length
        self.subword_regularization_alpha = subword_regularization_alpha
        self.sp = sp
        self.bos_id = sp.PieceToId("<s>")
        self.eos_id = sp.PieceToId("</s>")
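Example #5 stores subword_regularization_alpha, but the tokenization path that consumes it is not shown; below is a sketch of how it is typically used with SentencePiece sampling. The method name encode_program and the truncation scheme are assumptions about the rest of the class.
    def encode_program(self, program: str):
        # Hypothetical helper: sample a segmentation when alpha > 0 (subword
        # regularization), otherwise use the deterministic segmentation.
        if self.subword_regularization_alpha:
            ids = self.sp.SampleEncodeAsIds(program, -1, self.subword_regularization_alpha)
        else:
            ids = self.sp.EncodeAsIds(program)
        return [self.bos_id] + ids[: self.max_length - 2] + [self.eos_id]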
Example #6
0
def _evaluate(model,
              loader,
              sp: spm.SentencePieceProcessor,
              target_to_id,
              use_cuda=True,
              no_output_attention=False):
    model.eval()
    no_type_id = target_to_id["O"]
    any_id = target_to_id["$any$"]

    with torch.no_grad():
        # Accumulate metrics across batches to compute label-wise accuracy
        num1, num5, num_labels_total = 0, 0, 0
        num1_any, num5_any, num_labels_any_total = 0, 0, 0

        with Timer() as t:
            # Compute average loss
            total_loss = 0
            num_examples = 0
            pbar = tqdm.tqdm(loader, desc="evaluate")
            for X, lengths, output_attn, labels in pbar:
                if use_cuda:
                    X, lengths, output_attn, labels = (
                        X.cuda(), lengths.cuda(), output_attn.cuda(), labels.cuda())
                if no_output_attention:
                    logits = model(X, lengths, None)  # BxLxVocab
                else:
                    logits = model(X, lengths, output_attn)  # BxLxVocab
                # Compute loss
                loss = F.cross_entropy(logits.transpose(1, 2),
                                       labels,
                                       ignore_index=no_type_id)

                total_loss += loss.item() * X.size(0)
                num_examples += X.size(0)
                avg_loss = total_loss / num_examples

                # Compute accuracy
                (corr1_any, corr5_any), num_labels_any = accuracy(
                    logits.cpu(),
                    labels.cpu(),
                    topk=(1, 5),
                    ignore_idx=(no_type_id, ))
                num1_any += corr1_any
                num5_any += corr5_any
                num_labels_any_total += num_labels_any

                (corr1, corr5), num_labels = accuracy(logits.cpu(),
                                                      labels.cpu(),
                                                      topk=(1, 5),
                                                      ignore_idx=(no_type_id,
                                                                  any_id))
                num1 += corr1
                num5 += corr5
                num_labels_total += num_labels

                pbar.set_description(
                    f"evaluate average loss {avg_loss:.4f} num1 {num1_any} num_labels_any_total {num_labels_any_total} avg acc1_any {num1_any / (num_labels_any_total + 1e-6) * 100:.4f}"
                )

        # Average accuracies
        acc1 = float(num1) / num_labels_total * 100
        acc5 = float(num5) / num_labels_total * 100
        acc1_any = float(num1_any) / num_labels_any_total * 100
        acc5_any = float(num5_any) / num_labels_any_total * 100

        logger.debug(f"Loss calculation took {t.interval:.3f}s")
        return (
            -acc1_any,
            {
                "eval/loss": avg_loss,
                "eval/acc@1": acc1,
                "eval/acc@5": acc5,
                "eval/num_labels": num_labels_total,
                "eval/acc@1_any": acc1_any,
                "eval/acc@5_any": acc5_any,
                "eval/num_labels_any": num_labels_any_total,
            },
        )
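The accuracy helper called above is not shown; a minimal sketch consistent with how it is used, returning per-k correct counts over labels not in ignore_idx plus the number of labels counted. This is an assumed implementation, not the original.
import torch

def accuracy(logits, labels, topk=(1,), ignore_idx=()):
    # logits: [B, L, V]; labels: [B, L]
    with torch.no_grad():
        keep = torch.ones_like(labels, dtype=torch.bool)
        for idx in ignore_idx:
            keep &= labels != idx
        kept_labels = labels[keep]          # [N]
        kept_logits = logits[keep]          # [N, V]
        num_labels = kept_labels.numel()
        _, pred = kept_logits.topk(max(topk), dim=-1)         # [N, maxk]
        correct = pred.eq(kept_labels.unsqueeze(-1))          # [N, maxk]
        corrects = tuple(correct[:, :k].any(dim=-1).sum().item() for k in topk)
        return corrects, num_labels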
Example #7
0
def _evaluate_edit_distance(loaders, sp, pad_id, edit_distance_mode="tokens", save_path=None):
    # TODO: implement adversarial evaluation for edit distance
    # assert len(loaders) == 1
    # loader = loaders[0]

    assert edit_distance_mode == "tokens"
    ray.init()

    @ray.remote
    def _compute_similarity(X, lengths, labels):
        two_B, adversarial_samples, L = X.shape
        B = two_B // 2

        assert X.ndim == 3
        assert lengths.ndim == 2
        assert lengths.shape == (two_B, adversarial_samples)
        assert labels.shape == (B,)

        # Compute similarity
        X = X.view(2, B*adversarial_samples, L)
        lengths = lengths.view(2, B*adversarial_samples)
        similarity = torch.zeros(B*adversarial_samples, dtype=torch.float32)

        for i in range(B*adversarial_samples):
            a_len = lengths[0, i]
            b_len = lengths[1, i]
            a = list(X[0, i].numpy())[:a_len]  # remove padding
            b = list(X[1, i].numpy())[:b_len]  # remove padding
            # a = list(X[0, i].numpy())[1:a_len-1]  # remove bos_id, eos_id and padding
            # b = list(X[1, i].numpy())[1:b_len-1]  # remove bos_id, eos_id and padding
            similarity[i] = _levenshtein_similarity(a, b)  # B

        similarity = similarity.view(B, adversarial_samples)

        # Aggregate adversarially. similarity is [B, adversarial_samples]
        min_similarity, _ = torch.min(similarity, dim=1)
        max_similarity, _ = torch.max(similarity, dim=1)
        similarity = min_similarity * labels + max_similarity * (1 - labels)

        return similarity.numpy()  # [B,]

    with Timer() as t:
        # Compute average loss
        adversarial_samples = len(loaders)
        total_similarity = 0
        num_examples = 0

        y_true = []
        y_scores = []

        generator = _get_generator(loaders, pad_id=pad_id, pbar=True, pbar_desc="queue up evaluate")
        similarity_futures = []
        for X, lengths, labels in generator:
            f = _compute_similarity.remote(X, lengths, labels)
            similarity_futures.append(f)

            num_examples += X.size(0) // 2  # one example per (a, b) pair in the batch
            y_true.append(labels.numpy())

        # Aggregate futures, compute ROC AUC, AP
        pbar = tqdm.tqdm(map(ray.get, similarity_futures), desc="get similarities")
        for i, similarity in enumerate(pbar):
            total_similarity += np.sum(similarity)
            avg_similarity = total_similarity / num_examples
            y_scores.append(similarity)

            if i % 10 == 0:
                ytc = np.concatenate(y_true[: i + 1])
                num_pos, num_neg = np.sum(ytc), np.sum(ytc == 0)
                if ytc.sum() != 0 and ytc.sum() < len(ytc):
                    ysc = np.concatenate(y_scores)
                    roc_auc = roc_auc_score(ytc, ysc)
                    se_roc_auc = se_auc(roc_auc, num_pos, num_neg)
                    ap_score = average_precision_score(ytc, ysc)
                else:
                    roc_auc = 0
                    ap_score = 0
                    se_roc_auc = 0
                pbar.set_description(f"evaluate with {adversarial_samples} adversarial samples: avg similarity {avg_similarity:.4f} roc_auc {roc_auc:.4f}pm{se_roc_auc:.4f} ap {ap_score:.4f}")

        # Compute ROC AUC and AP
        y_true = np.concatenate(y_true)
        num_pos, num_neg = np.sum(y_true), np.sum(y_true == 0)
        y_scores = np.concatenate(y_scores)
        roc_auc = roc_auc_score(y_true, y_scores)
        se_roc_auc = se_auc(roc_auc, num_pos, num_neg)
        ap_score = average_precision_score(y_true, y_scores)

    logger.debug(f"Loss calculation took {t.interval:.3f}s")
    metrics = {
        f"eval/roc_auc_score/adv{adversarial_samples}": roc_auc,
        f"eval/se_roc_auc/adv{adversarial_samples}": se_roc_auc,
        f"eval/ap_score/adv{adversarial_samples}": ap_score,
        f"eval/num_examples/adv{adversarial_samples}": num_examples,
        f"eval/num_positive/adv{adversarial_samples}": np.sum(y_true),
    }

    if save_path:
        logger.info("Saving labels, scores and metrics to {}", save_path)
        torch.save({"y_true": y_true, "y_scores": y_scores, "metrics": metrics}, save_path)

    return (-roc_auc, metrics)
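_levenshtein_similarity is referenced above but not defined; a pure-Python sketch of one plausible definition (edit distance normalized into a similarity in [0, 1]). This is an assumption about the helper, not its original code.
def _levenshtein_similarity(a, b):
    # Classic O(len(a) * len(b)) DP edit distance, converted to a similarity.
    if not a and not b:
        return 1.0
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        curr = [i]
        for j, y in enumerate(b, 1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (x != y)))  # substitution
        prev = curr
    return 1.0 - prev[-1] / max(len(a), len(b))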
Example #8
0
def _evaluate(model, loaders, sp: spm.SentencePieceProcessor, pad_id, use_cuda=True, save_path=None):
    model.eval()

    with Timer() as t:
        # Compute ROC AUC, AP
        num_examples_by_adversarial_samples = defaultdict(lambda: 0)
        y_true_by_adversarial_samples = defaultdict(list)
        y_scores_by_adversarial_samples = defaultdict(list)
        generator = _get_generator(loaders, pad_id=pad_id, pbar=True, pbar_desc="evaluate")

        def _summarize_stats(y_true, y_scores):
            ytc = np.concatenate(y_true)
            ysc = np.concatenate(y_scores)
            accuracy = best_binary_accuracy_score(ytc, ysc)
            num_pos, num_neg = np.sum(ytc), np.sum(ytc == 0)

            roc_auc, se_roc_auc, ap_score = 0, 0, 0
            if ytc.sum() != 0 and ytc.sum() < len(ytc):
                roc_auc = roc_auc_score(ytc, ysc)
                se_roc_auc = se_auc(roc_auc, num_pos, num_neg)
                ap_score = average_precision_score(ytc, ysc)

            return roc_auc, se_roc_auc, ap_score, accuracy, num_pos, num_neg

        for X, lengths, labels in generator:
            if use_cuda:
                X, lengths, labels = X.cuda(), lengths.cuda(), labels.cuda()
            two_B, adversarial_samples, L = X.shape

            # Compute similarity
            similarity = model(X.view(two_B*adversarial_samples, L), lengths.view(-1))  # B*adversarial_samples
            assert similarity.ndim == 1

            # Reduce similarity adversarially. Maximize over sliding windows so we can compute
            # metrics for 1, 2, ... adversarial_samples samples of transforms.
            similarity = similarity.view(two_B // 2, adversarial_samples)
            for window_size in range(1, adversarial_samples + 1):
                for start in range(0, adversarial_samples - window_size + 1):
                    window_similarity = similarity[:, start:start+window_size]
                    min_similarity, _ = torch.min(window_similarity, dim=1)
                    max_similarity, _ = torch.max(window_similarity, dim=1)
                    window_similarity = min_similarity * labels + max_similarity * (1 - labels)

                    num_examples_by_adversarial_samples[window_size] += 1
                    y_scores_by_adversarial_samples[window_size].append(window_similarity.cpu().numpy())
                    y_true_by_adversarial_samples[window_size].append(labels.cpu().numpy())

            if num_examples_by_adversarial_samples[adversarial_samples] % 100 == 0:
                for window_size in range(1, adversarial_samples + 1):
                    roc_auc, se_roc_auc, ap_score, accuracy, num_pos, num_neg = _summarize_stats(
                        y_true_by_adversarial_samples[window_size],
                        y_scores_by_adversarial_samples[window_size])
                    print(f"evaluate with {window_size} adversarial samples: roc_auc {roc_auc:.4f}±{se_roc_auc:.4f} (+{num_pos},-{num_neg}) ap {ap_score:.4f} acc {accuracy:.4f}")

    # Compute ROC AUC and AP
    metrics = {}

    for window_size in range(1, adversarial_samples + 1):
        roc_auc, se_roc_auc, ap_score, accuracy, num_pos, num_neg = _summarize_stats(
            y_true_by_adversarial_samples[window_size],
            y_scores_by_adversarial_samples[window_size])

        print(f"evaluate with {window_size} adversarial samples: roc_auc {roc_auc:.4f}±{se_roc_auc:.4f} (+{num_pos},-{num_neg}) ap {ap_score:.4f} acc {accuracy:.4f}")

        metrics.update({
            f"eval/roc_auc_score/adv{window_size}": roc_auc,
            f"eval/se_roc_auc/adv{window_size}": se_roc_auc,
            f"eval/ap_score/adv{window_size}": ap_score,
            f"eval/accuracy/adv{window_size}": accuracy,
            f"eval/num_examples/adv{window_size}": num_examples_by_adversarial_samples[window_size],
            f"eval/num_positive/adv{window_size}": num_pos,
            f"eval/num_negative/adv{window_size}": num_neg,
        })

    logger.debug(f"Loss calculation took {t.interval:.3f}s")

    if save_path:
        logger.info("Saving labels, scores and metrics to {}", save_path)
        torch.save({"y_true": dict(y_true_by_adversarial_samples), "y_scores": dict(y_scores_by_adversarial_samples),
                    "num_examples": dict(num_examples_by_adversarial_samples), "metrics": metrics}, save_path)

    return (-roc_auc, metrics)
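se_auc(roc_auc, num_pos, num_neg) is used above but not shown; a sketch matching that call signature using the Hanley-McNeil (1982) standard error of an AUC estimate. Treat it as an assumption about the original helper.
import math

def se_auc(auc, num_pos, num_neg):
    # Hanley & McNeil (1982) approximation of the AUC standard error.
    q1 = auc / (2.0 - auc)
    q2 = 2.0 * auc * auc / (1.0 + auc)
    variance = (auc * (1.0 - auc)
                + (num_pos - 1) * (q1 - auc * auc)
                + (num_neg - 1) * (q2 - auc * auc)) / (num_pos * num_neg)
    return math.sqrt(max(variance, 0.0))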
Example #9
0
def _evaluate_edit_distance(loader,
                            sp,
                            edit_distance_mode="tokens",
                            save_path=None):
    assert edit_distance_mode == "tokens"
    ray.init()

    @ray.remote
    def _compute_similarity(X, lengths):
        assert X.ndim == 2
        assert lengths.ndim == 1

        # Compute similarity
        X = X.view(2, X.size(0) // 2, X.size(1))
        lengths = lengths.view(2, lengths.size(0) // 2)
        similarity = np.zeros(X.size(1), dtype=np.float32)

        for i in range(X.size(1)):
            a_len = lengths[0, i]
            b_len = lengths[1, i]
            a = list(X[0, i].numpy())[:a_len]  # remove padding
            b = list(X[1, i].numpy())[:b_len]  # remove padding
            # a = list(X[0, i].numpy())[1:a_len-1]  # remove bos_id, eos_id and padding
            # b = list(X[1, i].numpy())[1:b_len-1]  # remove bos_id, eos_id and padding
            similarity[i] = _levenshtein_similarity(a, b)  # B

        return similarity

    with Timer() as t:
        # Compute average loss
        total_similarity = 0
        num_examples = 0

        y_true = []
        y_scores = []

        pbar = tqdm.tqdm(loader, desc="queue up evaluate")
        similarity_futures = []
        for X, lengths, labels in pbar:
            f = _compute_similarity.remote(X, lengths)
            similarity_futures.append(f)

            num_examples += X.size(0) // 2  # one example per (a, b) pair in the batch
            y_true.append(labels.numpy())

        # Aggregate futures, compute ROC AUC, AP
        pbar = tqdm.tqdm(map(ray.get, similarity_futures),
                         desc="get similarities")
        for i, similarity in enumerate(pbar):
            total_similarity += np.sum(similarity)
            avg_similarity = total_similarity / num_examples
            y_scores.append(similarity)

            if i % 10 == 0:
                ytc = np.concatenate(y_true[:i + 1])
                if ytc.sum() != 0 and ytc.sum() < len(ytc):
                    ysc = np.concatenate(y_scores)
                    roc_auc = roc_auc_score(ytc, ysc)
                    ap_score = average_precision_score(ytc, ysc)
                else:
                    roc_auc = 0
                    ap_score = 0
                pbar.set_description(
                    f"evaluate average similarity {avg_similarity:.4f} roc_auc {roc_auc:.4f} ap {ap_score:.4f}"
                )

        # Compute ROC AUC and AP
        y_true = np.concatenate(y_true)
        y_scores = np.concatenate(y_scores)
        roc_auc = roc_auc_score(y_true, y_scores)
        ap_score = average_precision_score(y_true, y_scores)

    logger.debug(f"Loss calculation took {t.interval:.3f}s")
    metrics = {
        "eval/roc_auc_score": roc_auc,
        "eval/ap_score": ap_score,
        "eval/num_examples": num_examples,
        "eval/num_positive": np.sum(y_true),
    }

    if save_path:
        logger.info("Saving labels, scores and metrics to {}", save_path)
        torch.save({
            "y_true": y_true,
            "y_scores": y_scores,
            "metrics": metrics
        }, save_path)

    return (-roc_auc, metrics)
Example #10
0
def _evaluate(model,
              loader,
              sp: spm.SentencePieceProcessor,
              use_cuda=True,
              save_path=None):
    model.eval()

    with Timer() as t:
        # Compute average loss
        total_loss = 0
        num_examples = 0
        # Compute ROC AUC, AP
        y_true = []
        y_scores = []
        pbar = tqdm.tqdm(loader, desc="evaluate")
        for X, lengths, labels in pbar:
            y_true.append(labels.numpy())
            if use_cuda:
                X, lengths, labels = X.cuda(), lengths.cuda(), labels.cuda()

            # Compute loss
            similarity = model(X, lengths)  # B
            loss = F.binary_cross_entropy_with_logits(similarity,
                                                      labels.float())

            total_loss += loss.item() * X.size(0)
            num_examples += X.size(0)
            avg_loss = total_loss / num_examples

            y_scores.append(similarity.cpu().numpy())

            ytc = np.concatenate(y_true)
            if ytc.sum() != 0 and ytc.sum() < len(ytc):
                ysc = np.concatenate(y_scores)
                roc_auc = roc_auc_score(ytc, ysc)
                ap_score = average_precision_score(ytc, ysc)
            else:
                roc_auc = 0
                ap_score = 0
            pbar.set_description(
                f"evaluate average loss {avg_loss:.4f} roc_auc {roc_auc:.4f} ap {ap_score:.4f}"
            )

        # Compute ROC AUC and AP
        y_true = np.concatenate(y_true)
        y_scores = np.concatenate(y_scores)
        roc_auc = roc_auc_score(y_true, y_scores)
        ap_score = average_precision_score(y_true, y_scores)

    logger.debug(f"Loss calculation took {t.interval:.3f}s")
    metrics = {
        "eval/loss": avg_loss,
        "eval/roc_auc_score": roc_auc,
        "eval/ap_score": ap_score,
        "eval/num_examples": num_examples,
        "eval/num_positive": np.sum(y_true),
    }

    if save_path:
        logger.info("Saving labels, scores and metrics to {}", save_path)
        torch.save({
            "y_true": y_true,
            "y_scores": y_scores,
            "metrics": metrics
        }, save_path)

    return (-roc_auc, metrics)
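Example #10 feeds raw similarity logits to roc_auc_score and average_precision_score; both metrics are rank-based, so skipping the sigmoid is harmless. A tiny standalone check of that property on made-up scores.
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

scores = np.array([-2.0, -0.5, 0.3, 1.7])
labels = np.array([0, 1, 0, 1])
probs = 1.0 / (1.0 + np.exp(-scores))  # monotone transform of the logits
assert roc_auc_score(labels, scores) == roc_auc_score(labels, probs)
assert average_precision_score(labels, scores) == average_precision_score(labels, probs)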
Example #11
0
def calculate_f1_metric(
    metric: F1MetricMethodName,
    model,
    test_loader,
    sp: spm.SentencePieceProcessor,
    use_cuda=True,
    use_beam_search=True,
    beam_search_k=10,
    per_node_k=None,
    max_decode_len=20,  # see empirical evaluation of CDF of subword token lengths
    beam_search_sampler="deterministic",
    top_p_threshold=0.9,
    top_p_temperature=1.0,
    logger_fn=None,
    constrain_decoding=False,
):
    with Timer() as t:
        sample_generations = []
        n_examples = 0
        precision, recall, f1 = 0.0, 0.0, 0.0
        with tqdm.tqdm(test_loader, desc="test") as pbar:
            for X, Y, _, _ in pbar:
                if use_cuda:
                    X, Y = X.cuda(), Y.cuda()  # B, L
                with Timer() as decode_timer:
                    if use_beam_search:
                        pred, _ = beam_search_decode(
                            model,
                            X,
                            sp,
                            max_decode_len=max_decode_len,
                            constrain_decoding=constrain_decoding,
                            k=beam_search_k,
                            per_node_k=per_node_k,
                            sampler=beam_search_sampler,
                            top_p_threshold=top_p_threshold,
                            top_p_temperature=top_p_temperature,
                        )
                    else:
                        pred = greedy_decode(model,
                                             X,
                                             sp,
                                             max_decode_len=max_decode_len)
                for i in range(X.size(0)):
                    gt_identifier = ids_to_strs(Y[i], sp)
                    pred_dict = {"gt": gt_identifier}
                    if use_beam_search:
                        top_beam = pred[i][0]
                        tqdm.tqdm.write("{:>20} vs. gt {:<20}".format(
                            pred[i][0], gt_identifier))
                        for j, beam_result in enumerate(pred[i]):
                            pred_dict[f"pred_{j}"] = beam_result
                    else:
                        top_beam = pred[i]
                        pred_dict["pred_0"] = top_beam
                    sample_generations.append(pred_dict)
                    precision_item, recall_item, f1_item = metric(
                        top_beam, gt_identifier)

                    B = X.size(0)
                    n_examples += B
                    precision += precision_item * B
                    precision_avg = precision / n_examples
                    recall += recall_item * B
                    recall_avg = recall / n_examples
                    f1 += f1_item * B
                    if precision_avg or recall_avg:
                        f1_overall = 2 * (precision_avg * recall_avg) / (
                            precision_avg + recall_avg)
                    else:
                        f1_overall = 0.0
                    item_metrics = {
                        "precision_item": precision_item,
                        "recall_item": score_item,
                        "f1_item": f1_item
                    }
                    avg_metrics = {
                        "precision_avg": precision_avg,
                        "recall_avg": recall_avg,
                        "f1_avg": f1 / n_examples,
                        "f1_overall": f1_overall,
                    }
                    pbar.set_postfix(avg_metrics)
                    if logger_fn is not None:
                        logger_fn(item_metrics)
                        logger_fn(avg_metrics)
    logger.debug(
        f"Test set evaluation (F1) took {t.interval:.3}s over {n_examples} samples"
    )
    precision_avg = precision / n_examples
    recall_avg = recall / n_examples
    if precision_avg or recall_avg:
        f1_overall = 2 * (precision_avg * recall_avg) / (precision_avg + recall_avg)
    else:
        f1_overall = 0.0
    return precision_avg, recall_avg, f1_overall, sample_generations
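Note that Example #11 returns f1_overall, the harmonic mean of the averaged precision and recall, which generally differs from f1_avg, the mean of per-item F1 scores. A tiny illustration of the gap on made-up numbers.
p_items = [1.0, 0.2]
r_items = [0.5, 1.0]
f1_items = [2 * p * r / (p + r) for p, r in zip(p_items, r_items)]
f1_avg = sum(f1_items) / len(f1_items)                # mean of per-item F1 = 0.5
p_avg, r_avg = sum(p_items) / 2, sum(r_items) / 2
f1_overall = 2 * p_avg * r_avg / (p_avg + r_avg)      # harmonic mean of averages ≈ 0.67
print(f1_avg, f1_overall)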