Example no. 1
    def convert(self, kind="user"):
        """Convert underlying metric objects.

        Conversion to user format returns a dictionary with each element mapping
        metric name to metric value. Conversion to db format returns a
        list of dictionaries, each with keys "name", "scoring", and "value"
        mapping to their respective values. Both formats convert np.floating
        values to Python floats.

        Parameters
        ----------
        kind : str
            One of "user" or "db"
        """
        if kind == "user":
            metrics = {}
            for m in self._list:
                metrics.update(m.convert(kind="user"))
        elif kind == "db":
            metrics = []
            for m in self._list:
                metrics.append(m.convert(kind="db"))
        else:
            raise ValueError("Bad kind: {}".format(kind))

        return metrics
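
A minimal sketch of the two return shapes the docstring describes, assuming a hypothetical container holding an accuracy and an f1 metric (the metric names, the values, and the "scoring" strings are illustrative):

# Illustrative only: shapes follow the docstring of convert() above.
user_format = {"accuracy": 0.93, "f1": 0.88}
db_format = [
    {"name": "accuracy", "scoring": "accuracy", "value": 0.93},
    {"name": "f1", "scoring": "f1", "value": 0.88},
]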
Example no. 2
def log_run(split: str, epoch: int, writer: tf.summary.SummaryWriter,
            label_names: Sequence[str], metrics: MutableMapping[str, float],
            heaps: Mapping[str, Mapping[int, List[HeapItem]]],
            cm: np.ndarray) -> None:
    """Logs the outputs (metrics, confusion matrix, tp/fp/fn images) from a
    single epoch run to TensorBoard.

    Args:
        metrics: dict, keys already prefixed with {split}/
    """
    per_class_recall = recall_from_confusion_matrix(cm, label_names)
    metrics.update(prefix_all_keys(per_class_recall, f'{split}/label_recall/'))

    # log metrics
    for metric, value in metrics.items():
        tf.summary.scalar(metric, value, epoch)

    # log confusion matrix
    cm_fig = plot_utils.plot_confusion_matrix(cm,
                                              classes=label_names,
                                              normalize=True)
    cm_fig_img = tf.convert_to_tensor(fig_to_img(cm_fig)[np.newaxis, ...])
    tf.summary.image(f'confusion_matrix/{split}', cm_fig_img, step=epoch)

    # log tp/fp/fn images
    for heap_type, heap_dict in heaps.items():
        log_images_with_confidence(heap_dict,
                                   label_names,
                                   epoch=epoch,
                                   tag=f'{split}/{heap_type}')
    writer.flush()
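
In TF2, `tf.summary.scalar` and `tf.summary.image` write to whichever writer is currently installed as the default, so the caller presumably wraps this function in `writer.as_default()`. A sketch of such a call site (log directory and argument values are placeholders):

writer = tf.summary.create_file_writer("/tmp/logs/val")  # placeholder path
with writer.as_default():
    log_run("val", epoch, writer, label_names, metrics, heaps, cm)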
Example no. 3
def log_run(split: str, epoch: int, writer: tensorboard.SummaryWriter,
            label_names: Sequence[str], metrics: MutableMapping[str, float],
            heaps: Optional[Mapping[str, Mapping[int, list[HeapItem]]]],
            cm: np.ndarray) -> None:
    """Logs the outputs (metrics, confusion matrix, tp/fp/fn images) from a
    single epoch run to TensorBoard.

    Args:
        metrics: dict, keys already prefixed with {split}/
    """
    per_label_recall = recall_from_confusion_matrix(cm, label_names)
    metrics.update(prefix_all_keys(per_label_recall, f'{split}/label_recall/'))

    # log metrics
    for metric, value in metrics.items():
        writer.add_scalar(metric, value, epoch)

    # log confusion matrix
    cm_fig = plot_utils.plot_confusion_matrix(cm, classes=label_names,
                                              normalize=True)
    cm_fig_img = fig_to_img(cm_fig)
    writer.add_image(tag=f'confusion_matrix/{split}', img_tensor=cm_fig_img,
                     global_step=epoch, dataformats='HWC')

    # log tp/fp/fn images
    if heaps is not None:
        for heap_type, heap_dict in heaps.items():
            log_images_with_confidence(writer, heap_dict, label_names,
                                       epoch=epoch, tag=f'{split}/{heap_type}')
    writer.flush()
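
The PyTorch variant above calls the writer's own methods (`add_scalar`, `add_image`), so no default-writer context is needed. An illustrative call site, assuming `tensorboard` is `torch.utils.tensorboard` and the path is a placeholder:

from torch.utils import tensorboard

writer = tensorboard.SummaryWriter(log_dir="/tmp/logs/val")  # placeholder path
log_run("val", epoch, writer, label_names, metrics, heaps, cm)
writer.close()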
Example no. 4
def _evaluate_predictions(ctx, pipeline, X, y):
    outputs = pipeline._predict(X)

    if ctx.is_regression:
        metrics = _evaluate_regressor(ctx, y, outputs.predictions)
        metrics.update(_extra_regressor_metrics(ctx, pipeline, metrics,
                                                len(y)))
    else:
        encoder = pipeline.label_encoder
        # MAY: Eventually use the probabilities, too.
        predicted = encoder.transform(outputs.predictions)
        metrics = _evaluate_classifier(ctx, y, predicted, encoder)

    return MetricsReport(outputs.predictions, outputs.probabilities, metrics)
Example no. 5
def rank_classification(targets, predictions, num_classes=2):
    """Computes standard classification metrics based on log likelihood ranking.

  This metric is intended to be used along with the `rank_classification`
  preprocessor and postprocessor. Each example is scored (by log likelihood)
  for every possible label, and the label with the best score is selected as the
  prediction.

  Args:
    targets: list of int, the true label value for each aligned "prediction"
      score.
    predictions: list of float, a flat list of log likelihood scores for each
      possible label for each example.
    num_classes: int, the number of possible classes for the label.
  Returns:
    Accuracy, f1, and AUC scores.
  """
    assert len(targets) == len(predictions)
    assert len(targets) % num_classes == 0

    labels = np.array(targets[::num_classes])
    labels_onehot = np.eye(num_classes)[labels]

    log_likelihoods = np.array(predictions, np.float32).reshape(
        (-1, num_classes))
    predictions = log_likelihoods.argmax(-1)

    def exp_normalize(x):
        b = x.max(-1)[:, np.newaxis]
        y = np.exp(x - b)
        return y / y.sum(-1)[:, np.newaxis]

    probs = exp_normalize(log_likelihoods)

    if num_classes > 2:
        metrics = mean_multiclass_f1(num_classes)(labels, predictions)
    else:
        metrics = {"f1": 100 * sklearn.metrics.f1_score(labels, predictions)}
    metrics.update({
        "auc-roc":
        100 *
        sklearn.metrics.roc_auc_score(labels_onehot, probs, multi_class="ovr"),
        "auc-pr":
        100 * sklearn.metrics.average_precision_score(labels_onehot, probs),
        "accuracy":
        100 * sklearn.metrics.accuracy_score(labels, predictions),
    })
    return metrics
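
A small worked sketch of the ranking step above for `num_classes=2` and three examples: the flat score list is reshaped to one row per example, the per-row argmax is the predicted label, and `exp_normalize` is the usual numerically stable softmax (subtracting the row maximum before exponentiating leaves the result unchanged), so each row of `probs` sums to 1. Values are illustrative:

import numpy as np

# Flat log likelihoods: [ex0_label0, ex0_label1, ex1_label0, ex1_label1, ...]
scores = np.array([-1.2, -0.3, -0.1, -2.0, -0.9, -0.8], np.float32)
log_likelihoods = scores.reshape((-1, 2))  # one row per example
predictions = log_likelihoods.argmax(-1)   # array([1, 0, 1])
probs = np.exp(log_likelihoods - log_likelihoods.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)      # rows now sum to 1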
Example no. 6
def get_default_characteristics(auto, X, y):
    metrics = {}
    best_model = get_best_model(auto)
    metrics["Dataset"] = get_dataset(auto)
    metrics["Best-Model"] = get_best_model_name(auto)
    metrics["Hyperparameters"] = get_sorted_params(auto)
    metrics["Selection_alg"] = get_hpo(auto)
    best_model = SimpleClassificationPipeline(best_model)
    if auto.ensemble_size != 0:
        metrics.update(
            prefix_dict_keys(
                "Ensemble",
                get_cross_validation_metrics(auto, X, y, DEFAULT_METRICS)))
    metrics.update(
        prefix_dict_keys(
            "Best_model",
            get_cross_validation_metrics(best_model, X, y, DEFAULT_METRICS)))
    return metrics
Example no. 7
def rank_classification(
        targets: Sequence[Tuple[int, bool, float]],
        scores: Sequence[float],
        num_classes: Optional[int] = None) -> Dict[str, Union[float, int]]:
    """Computes standard metrics classification based on log likelihood ranking.

  This metric is intended to be used along with the `rank_classification`
  preprocessor and postprocessor. Each example is scored (by log likelihood)
  for every possible label, and the label with the best score is selected as the
  prediction.

  In the case of multiple labels, a prediction matching any will be considered
  correct.

  Args:
    targets: list of tuples, the 'idx', 'is_correct' and 'weight' fields from
      ground truth examples.
    scores: list of float, a flat list of log likelihood scores for each
      possible label for each example.
    num_classes: int or None, the number of possible classes for the label or
      None if the number of classes vary.

  Returns:
    Accuracy, f1, and AUC scores.

  Raises:
    ValueError: if `targets` is not a sequence of 3-tuples.
  """
    assert len(targets) == len(scores)
    if len(targets[0]) != 3:
        raise ValueError(
            "`targets` should contain three elements. Only %d are provided." %
            len(targets[0]))

    if not num_classes:
        # Assuming variable classes. Can only compute accuracy.
        num_correct = 0
        total = 0
        for _, grp in itertools.groupby(zip(targets, scores),
                                        lambda x: x[0][0]):
            exs, log_likelihoods = zip(*grp)
            prediction = np.argmax(log_likelihoods)
            weights = exs[prediction][2]
            num_correct += exs[prediction][1] * weights
            total += weights
        return {"accuracy": 100 * num_correct / total}

    assert len(targets) % num_classes == 0
    labels_indicator = np.array([is_correct
                                 for _, is_correct, _ in targets]).reshape(
                                     (-1, num_classes))
    weights = np.array([weight for _, _, weight in targets]).reshape(
        (-1, num_classes))[:, 0]
    log_likelihoods = np.array(scores, np.float32).reshape((-1, num_classes))
    predictions = log_likelihoods.argmax(-1)

    if np.any(labels_indicator.sum(axis=-1) > 1):
        # multiple-answer case
        logging.info(
            "Multiple labels detected. Predictions matching any label will be "
            "considered correct.")
        num_examples = len(labels_indicator)
        return {
            "accuracy":
            (100 *
             np.average(labels_indicator[np.arange(num_examples), predictions],
                        weights=weights))
        }

    predictions_indicator = np.eye(num_classes)[predictions]

    def exp_normalize(x):
        b = x.max(-1)[:, np.newaxis]
        y = np.exp(x - b)
        return y / y.sum(-1)[:, np.newaxis]

    probs = exp_normalize(log_likelihoods)

    if num_classes > 2:
        metrics = mean_multiclass_f1(num_classes, sample_weight=weights)(
            labels_indicator, predictions_indicator)
    else:
        metrics = {
            "f1":
            100 * sklearn.metrics.f1_score(labels_indicator.argmax(-1),
                                           predictions,
                                           sample_weight=weights)
        }
    metrics.update({
        "auc-roc":
        100 * sklearn.metrics.roc_auc_score(
            labels_indicator, probs, multi_class="ovr", sample_weight=weights),
        "auc-pr":
        100 * sklearn.metrics.average_precision_score(
            labels_indicator, probs, sample_weight=weights),
        "accuracy":
        100 * sklearn.metrics.accuracy_score(
            labels_indicator, predictions_indicator, sample_weight=weights),
    })
    return metrics
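
For the `num_classes=None` branch, a tiny worked example of the grouping logic (targets and scores are illustrative): candidates are grouped by 'idx', the best-scoring candidate in each group becomes the prediction, and the example's weight is credited only if that candidate is marked correct.

# Two examples: the first has 3 candidate labels, the second has 2.
targets = [(0, False, 1.0), (0, True, 1.0), (0, False, 1.0),
           (1, True, 2.0), (1, False, 2.0)]
scores = [-2.0, -0.5, -1.0, -0.2, -0.7]
# Example 0: argmax is candidate 1 (score -0.5), is_correct=True -> +1.0
# Example 1: argmax is candidate 0 (score -0.2), is_correct=True -> +2.0
# accuracy = 100 * (1.0 + 2.0) / (1.0 + 2.0) = 100.0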
Example no. 8
    def compute_classification_metrics(self,
                                       labels,
                                       pred,
                                       output=None,
                                       to_print=True,
                                       epoch=None,
                                       is_final_model=False):
        target = labels.numpy().reshape(-1).copy()
        target[target == -1] = 0
        pred = pred.cpu().numpy().reshape(-1)
        pred[pred == -1] = 0
        accuracy = sklearn.metrics.accuracy_score(target, pred)
        self.test_accuracies.append(accuracy)
        # With binary (0/1) predictions, roc_auc_score equals balanced accuracy
        balanced_accuracy =\
            sklearn.metrics.roc_auc_score(target, pred)
        # NOTE: `output` is effectively required: `auc_score` and `output` are
        # used unconditionally in the metrics below, so passing output=None
        # would raise a NameError there.
        if output is not None:
            output = output.cpu().numpy().reshape(-1)
            auc_score = sklearn.metrics.roc_auc_score(target, output)
        # Return result for each class
        precision, recall, f1_score, _ =\
            sklearn.metrics.precision_recall_fscore_support(target, pred)
        n_negative = np.sum(target == 0)
        n_false_positive = np.sum(np.logical_and(pred == 1, target == 0))

        prefix = "heur_" if is_final_model else ""
        # First add the metrics as they were computed in the original code.
        metrics = {
            f"{prefix}orig_test_acc":
            accuracy,
            f"{prefix}orig_test_tpr":
            np.sum(np.logical_and(pred == 1, target == 1)) /
            np.sum(target == 1),
            f"{prefix}orig_test_tnr":
            np.sum(np.logical_and(pred == 0, target == 0)) /
            np.sum(target == 0),
            f"{prefix}orig_auroc":
            auc_score,
        }

        # Then we add our own metrics. For the sake of consistency with the
        # other OOD detection methods, we will label the training set as
        # negative and the outliers as positives (so the opposite of how PU
        # learning labels them).
        our_metrics = lib_data.get_eval_metrics(
            id_y_true=1 - target[target == 1],
            id_y_pred=1 - pred[target == 1],
            id_test_statistic=1 - output[target == 1],
            ood_y_true=1 - target[target == 0],
            ood_y_pred=1 - pred[target == 0],
            ood_test_statistic=1 - output[target == 0],
        )
        if is_final_model:
            our_metrics = {f"heur_{k}": v for k, v in our_metrics.items()}
        metrics.update(our_metrics)

        lib_data.retry(lambda: mlflow.log_metrics(metrics, step=epoch))

        if to_print:
            print('Test set: Accuracy: {:.2f}%'.format(accuracy * 100),
                  flush=True)
            print('Test set: Balanced Accuracy: {:.2f}%'.format(
                balanced_accuracy * 100),
                  flush=True)
            if output is not None:
                print('Test set: Auc Score: {:.2f}%'.format(auc_score * 100),
                      flush=True)
            print('Test set: Precision: {:.2f}%'.format(precision[1] * 100),
                  flush=True)
            print('Test set: Recall Score: {:.2f}%'.format(recall[1] * 100),
                  flush=True)
            print('Test set: F1 Score: {:.2f}%'.format(f1_score[1] * 100),
                  flush=True)
            print('Test set: False Positive Rate: {:.2f}%'.format(
                n_false_positive / n_negative * 100),
                  flush=True)
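
The `1 - ...` relabeling in the call to `lib_data.get_eval_metrics` flips the PU-learning convention described in the comment: in-distribution samples (label 1 above) become 0/negative and outliers (label 0) become 1/positive. A tiny numpy illustration with placeholder values:

import numpy as np

target = np.array([1, 1, 0, 0, 1])  # PU convention: 1 = training set, 0 = outlier
print(1 - target[target == 1])      # [0 0 0] -> in-distribution relabeled negative
print(1 - target[target == 0])      # [1 1]   -> outliers relabeled positive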
Example no. 9
                #'level':level,
                'all validation samples':all_validation_samples,
                'misclassified images':false_positives+false_negatives,
                'false negatives':false_negatives,
                'false positives':false_positives,
                'rejected noise (%)':100*true_negatives/all_real_noises, # specificity
                'false alarms (%)':100*false_positives/all_predicted_as_signals, #all_real_noises, # TODO understand this better!!! (it should be 1 - purity) # "false alarm rate"
                'missed signals (%)':100*false_negatives/all_real_signals, # TODO is the false dismissal rate 1 - efficiency?
                'selected signals (%)':100*true_positives/all_real_signals,
                'purity (%)':100*purity,
                'efficiency (%)':100*efficiency,
                'accuracy (%)':100*accuracy}
    return metrics

metrics = compute_the_metrics(confusion_matrix)
metrics.update({'SNR':signal_to_noise_ratio, 'level':level})
results = pandas.DataFrame(metrics, index=[signal_to_noise_ratio])
results.to_csv('/storage/users/Muciaccia/burst/models/results_SNR_{}.csv'.format(signal_to_noise_ratio), index=False)
# TODO then concatenate the results with pandas.concat()
# TODO maybe drop the percentages and put back the normalized quantities

# NOTE: for a discovery at 5 sigma confidence one must look at the purity value (or the false alarm rate)

# in my opinion a slight overfitting develops at SNR 15 and SNR 10

# TODO try gradually decreasing both the dropout and the learning rate

######################

# plot the output histogram
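
The head of `compute_the_metrics` is cut off in the snippet above; a minimal reconstruction of how the missing quantities could be derived from a binary confusion matrix under the standard definitions (purity = precision, efficiency = recall). The row/column layout and the variable names are assumptions, not the author's actual code:

def compute_the_metrics(confusion_matrix):
    # assumed layout: rows = true class (0 = noise, 1 = signal), columns = predicted class
    true_negatives, false_positives = confusion_matrix[0]
    false_negatives, true_positives = confusion_matrix[1]
    all_validation_samples = confusion_matrix.sum()
    all_real_noises = true_negatives + false_positives
    all_real_signals = true_positives + false_negatives
    all_predicted_as_signals = true_positives + false_positives
    purity = true_positives/all_predicted_as_signals    # precision
    efficiency = true_positives/all_real_signals        # recall
    accuracy = (true_positives + true_negatives)/all_validation_samples
    return {'purity (%)': 100*purity,
            'efficiency (%)': 100*efficiency,
            'accuracy (%)': 100*accuracy}  # plus the other entries shown in the snippet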
 
Example no. 10
def rank_classification(
    targets: Sequence[Tuple[Sequence[int], bool, float, int]],
    scores: Sequence[float],
    num_classes: Optional[int] = None,
    normalize_by_target_length: bool = False,
) -> Dict[str, Union[float, int]]:
  """Computes standard metrics classification based on log likelihood ranking.

  This metric is intended to be used along with the `rank_classification`
  preprocessor and postprocessor. Each example is scored (by log likelihood)
  for every possible label, and the label with the best score is selected as the
  prediction.

  In the case of multiple labels, a prediction matching any will be considered
  correct.

  For problems with two labels, AUC-pr and AUC-roc retrieval metrics will be
  reported for the positive class, which is assumed to have an 'idx' of 1. If
  more labels are present, only accuracy and F-1 will be reported.

  Args:
    targets: list of tuples, the 'idx', 'is_correct', 'weight' fields, and
      length of target tokens from ground truth examples.
    scores: list of float, a flat list of log likelihood scores for each
      possible label for each example.
    num_classes: int or None, the number of possible classes for the label or
      None if the number of classes vary.
    normalize_by_target_length: bool, if True the scores are normalized by the
      target token lengths.
  Returns:
    Accuracy, f1, and AUC scores.

  Raises:
    ValueError: if `targets` is not a sequence of 4-tuples.
  """
  assert len(targets) == len(scores)
  if len(targets[0]) != 4:
    raise ValueError(
        f"`targets` should contain 4 elements but has {len(targets[0])}.")

  normalized_scores = []
  if normalize_by_target_length:
    for target, score in zip(targets, scores):
      _, _, _, target_length = target
      score = score / target_length
      normalized_scores.append(score)

    scores = normalized_scores

  idx_0 = targets[0][0]
  if not hasattr(idx_0, "__len__") or len(idx_0) != 2:
    raise ValueError(
        "The first element of `targets` ('idx') should be 2-dimensional. "
        f"Got {idx_0}.")

  # Sort by 'idx' since the function relies on this assumption.
  # ((idx, is_correct, weight, target_length), score)
  get_idx = lambda x: x[0][0]
  targets, scores = zip(*sorted(zip(targets, scores), key=get_idx))

  if not num_classes:
    # Assuming variable classes. Can only compute accuracy.
    num_correct = 0
    total = 0

    # (((input idx, output idx), is_correct, weight, target_length), score)
    get_grp = lambda x: x[0][0][0]

    for _, grp in itertools.groupby(zip(targets, scores), get_grp):
      exs, log_likelihoods = zip(*grp)
      prediction = np.argmax(log_likelihoods)
      weights = exs[prediction][2]
      num_correct += exs[prediction][1] * weights
      total += weights
    return {"accuracy": 100 * num_correct / total}

  assert len(targets) % num_classes == 0, f"{len(targets)} % {num_classes} != 0"

  labels_indicator = np.array([is_correct for _, is_correct, _, _ in targets
                              ]).reshape((-1, num_classes))
  weights = np.array([weight for _, _, weight, _ in targets]).reshape(
      (-1, num_classes))[:, 0]
  log_likelihoods = np.array(scores, np.float32).reshape((-1, num_classes))
  predictions = log_likelihoods.argmax(-1)

  if np.any(labels_indicator.sum(axis=-1) > 1):
    # multiple-answer case
    logging.info(
        "Multiple labels detected. Predictions matching any label will be "
        "considered correct.")
    num_examples = len(labels_indicator)
    return {
        "accuracy": (100 * np.average(
            labels_indicator[np.arange(num_examples), predictions],
            weights=weights))
    }

  predictions_indicator = np.eye(num_classes)[predictions]

  def exp_normalize(x):
    b = x.max(-1)[:, np.newaxis]
    y = np.exp(x - b)
    return y / y.sum(-1)[:, np.newaxis]
  probs = exp_normalize(log_likelihoods)

  metrics = {
      "accuracy":
          100 * sklearn.metrics.accuracy_score(
              labels_indicator, predictions_indicator, sample_weight=weights),
  }

  if num_classes > 2:
    metrics.update(
        mean_multiclass_f1(num_classes,
                           sample_weight=weights)(labels_indicator,
                                                  predictions_indicator))
    logging.warning("AUC-pr and AUC-roc are not supported when num_classes > 2")
  else:
    metrics.update({
        "f1":
            100 * sklearn.metrics.f1_score(
                labels_indicator.argmax(-1), predictions, sample_weight=weights)
    })
    labels_indicator = labels_indicator[:, 1]
    probs = probs[:, 1]

    metrics.update({
        "auc-roc":
            100 * sklearn.metrics.roc_auc_score(
                labels_indicator, probs, multi_class="ovr",
                sample_weight=weights, average="macro"),
        "auc-pr":
            100 * sklearn.metrics.average_precision_score(
                labels_indicator, probs, sample_weight=weights,
                average="macro"),
    })

  return metrics
Example no. 11
def evaluate_from_labeled_token_spans(truth_labeled_token_spans,
                                      predicted_labeled_token_spans, tokens):
    """Evaluate predicted spans against ground truth by span type.

  Args:
    truth_labeled_token_spans: List of LabeledTokenSpan for ground truth spans.
    predicted_labeled_token_spans: List of LabeledTokenSpan for predicted spans.
    tokens: List of token objects.

  Returns:
    Metrics dict by span type and evaluation view.
  """
    metrics = {}

    def labeled_token_span_size_nonspaces(labeled_token_span):
        return ap_parsing_utils.token_span_size_nonspaces(
            (labeled_token_span.start_token, labeled_token_span.end_token),
            tokens=tokens)

    # Calculate token labels:
    truth_token_labels = np.zeros(len(tokens), dtype=np.int64)
    pred_token_labels = np.zeros(len(tokens), dtype=np.int64)
    token_mask = np.array([
        token.token_type != tokenizer_lib.TokenType.SPACE for token in tokens
    ])

    for spans, labels in [(truth_labeled_token_spans, truth_token_labels),
                          (predicted_labeled_token_spans, pred_token_labels)]:
        for span in spans:
            labels[span.start_token:span.end_token] = span.span_type.value

    # Calculate for every span type (excluding UNKNOWN):
    for span_type in list(ap_parsing_lib.LabeledSpanType)[1:]:
        metrics.update(
            token_level_metrics(truth_token_labels, pred_token_labels,
                                token_mask, span_type))

        # Calculate span level metrics:
        cur_truth_labeled_token_spans = sorted_labeled_token_spans_by_type(
            truth_labeled_token_spans, span_type=span_type)
        cur_predicted_labeled_token_spans = sorted_labeled_token_spans_by_type(
            predicted_labeled_token_spans, span_type=span_type)

        truth_token_overlaps = _calculate_overlaps(
            cur_truth_labeled_token_spans, cur_predicted_labeled_token_spans,
            tokens)
        truth_token_span_sizes = list(
            map(labeled_token_span_size_nonspaces,
                cur_truth_labeled_token_spans))
        predicted_token_span_sizes = list(
            map(labeled_token_span_size_nonspaces,
                cur_predicted_labeled_token_spans))

        metrics.update(
            span_level_metrics(truth_token_overlaps, truth_token_span_sizes,
                               predicted_token_span_sizes, span_type))

        # Calculate action item type metrics:
        if span_type == ap_parsing_lib.LabeledSpanType.ACTION_ITEM:
            metrics.update(
                _get_action_item_metrics(truth_token_overlaps,
                                         cur_truth_labeled_token_spans,
                                         cur_predicted_labeled_token_spans))
    return metrics
Example no. 12
thresholds = [0.8, 0.85, 0.9, 0.95]
results = []
for weights in ["balanced", "unbalanced"]:
    reg = DecisionTreeClassifier(
        max_depth=6,
        min_samples_split=20,
        min_samples_leaf=5,
        random_state=50,
        class_weight=weights if weights == "balanced" else None).fit(
            X_train, y_train)
    train_metrics = show_metrics(y_train, reg.predict_proba(X_train),
                                 thresholds, "train", weights, "results")
    for metrics in train_metrics:
        metrics.update({"type": weights, "data": "train"})
    test_metrics = show_metrics(y_test, reg.predict_proba(X_test), thresholds,
                                "test", weights, "results")
    for metrics in test_metrics:
        metrics.update({"type": weights, "data": "test"})

    results += train_metrics + test_metrics

pd.DataFrame(results).to_csv("results/results.csv", index=False)
Example no. 13
def eval_epoch(model, data_loader, epoch, config, suffix=""):
    metrics = {
        "images": Concat(),
        "targets": Concat(),
        "logits": Concat(),
        "loss": Concat(),
    }

    # loop over batches ################################################################################################
    model.eval()
    with torch.no_grad():
        for images, meta, targets in tqdm(
            data_loader,
            desc="fold {}, epoch {}/{}, eval".format(config.fold, epoch, config.train.epochs),
        ):
            images, meta, targets = (
                images.to(DEVICE),
                {k: meta[k].to(DEVICE) for k in meta},
                targets.to(DEVICE),
            )

            logits = model(images, meta)
            loss = compute_loss(input=logits, target=targets, config=config)

            metrics["images"].update(images.data.cpu())
            metrics["targets"].update(targets.data.cpu())
            metrics["logits"].update(logits.data.cpu())
            metrics["loss"].update(loss.data.cpu())

    # compute metrics ##################################################################################################
    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        metrics.update(compute_metric(input=metrics["logits"], target=metrics["targets"]))
        images_hard_pos = topk_hardest(
            metrics["images"],
            metrics["loss"],
            metrics["targets"] > 0.5,
            topk=config.eval.batch_size,
        )
        images_hard_neg = topk_hardest(
            metrics["images"],
            metrics["loss"],
            metrics["targets"] <= 0.5,
            topk=config.eval.batch_size,
        )
        roc_curve = plot_roc_curve(input=metrics["logits"], target=metrics["targets"])
        metrics["loss"] = metrics["loss"].mean()

        writer = SummaryWriter(os.path.join(config.experiment_path, "eval", suffix))
        writer.add_image(
            "images/hard/pos",
            torchvision.utils.make_grid(
                images_hard_pos, nrow=compute_nrow(images_hard_pos), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "images/hard/neg",
            torchvision.utils.make_grid(
                images_hard_neg, nrow=compute_nrow(images_hard_neg), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_scalar("loss", metrics["loss"], global_step=epoch)
        writer.add_scalar("roc_auc", metrics["roc_auc"], global_step=epoch)
        writer.add_figure("roc_curve", roc_curve, global_step=epoch)

        writer.flush()
        writer.close()

    return metrics["roc_auc"]
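
`Concat` is not defined in the snippet; the loop only relies on an accumulator exposing `update` and `compute_and_reset`. A minimal sketch of such an accumulator, assuming per-sample tensors are appended batch by batch (an assumption, not the project's actual implementation):

import torch

class Concat:
    """Accumulates batch tensors and concatenates them along dim 0 on demand."""

    def __init__(self):
        self.values = []

    def update(self, value):
        self.values.append(value)

    def compute_and_reset(self):
        result = torch.cat(self.values, dim=0)
        self.values = []
        return result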