コード例 #1
0
    def _test(average=None):
        """Check ``IoU`` against a numpy reference, with and without ``ignore_index``.

        The reference per-class IoU is computed directly from the label arrays
        returned by ``get_y_true_y_pred()`` as
        ``|(y_true == k) & (y_pred == k)| / |(y_true == k) | (y_pred == k)|``
        for each of the 3 classes, then compared with
        ``IoU(ConfusionMatrix(num_classes=3, ...))`` fed the matching
        (logits, target) tensors.

        :param average: forwarded to every ``ConfusionMatrix`` built here,
            including the ones used for the ``ignore_index`` checks, so that
            those checks also exercise non-default averaging.
        """
        y_true, y_pred = get_y_true_y_pred()
        th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)
        # The same (logits, target) pair is fed to every metric instance below.
        output = (th_y_logits, th_y_true)

        # Reference per-class IoU: |A ∩ B| / |A ∪ B| on the boolean class masks.
        true_res = []
        for index in range(3):
            bin_y_true = y_true == index
            bin_y_pred = y_pred == index
            intersection = bin_y_true & bin_y_pred
            union = bin_y_true | bin_y_pred
            true_res.append(intersection.sum() / union.sum())

        def _compute_iou(ignore_index=None):
            # Fresh confusion matrix per check so no accumulated state leaks
            # between runs; ``average`` is applied consistently to all of them.
            cm = ConfusionMatrix(num_classes=3, average=average)
            iou_metric = IoU(cm, ignore_index=ignore_index)
            cm.update(output)
            return iou_metric.compute().numpy()

        # The metric's internal arithmetic (tensor dtype, normalisation when
        # ``average`` is set) need not match numpy bit-for-bit, so compare with
        # a tolerance rather than exact ``==``.
        res = _compute_iou()
        np.testing.assert_allclose(res, true_res, rtol=1e-5)

        for ignore_index in range(3):
            res = _compute_iou(ignore_index)
            # The ignored class is expected to be dropped from the result vector.
            true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1 :]
            np.testing.assert_allclose(
                res, true_res_, rtol=1e-5, err_msg=f"{ignore_index}: {res} vs {true_res_}"
            )
コード例 #2
0
def evaluate(model: AdapNet, dl: DataLoader, mode, batch_size=2):
    """
    Evaluates the model, uses IoU as the metric
    :param model: The model to evaluate. It is switched to eval mode with
        ``model.eval()`` and is left in that mode on return.
    :param dl: The data loader to draw samples from; it must expose
        ``test_set``, ``validation_set``, ``num_labels`` and
        ``sample_batch(batch_size, mode=...)``.
    :param mode: The evaluation mode, one of "test" or "validation". Any value
        other than "test" selects ``dl.validation_set`` for sizing the loop,
        while ``mode`` itself is passed through to ``dl.sample_batch``.
    :param batch_size: The batch size for the evaluation
    :return: The per-class IoU tensor computed from the accumulated confusion
        matrix; its mean is printed as mIoU.
    """
    # NOTE(review): the model is intentionally not restored to training mode
    # here; callers that keep training afterwards must call model.train().
    model.eval()

    # The split is only used to size the loop below; the batches themselves
    # come from dl.sample_batch(..., mode=mode).
    eval_set = dl.test_set if mode == "test" else dl.validation_set

    # NOTE(review): integer division drops the last len(eval_set) % batch_size
    # samples, and a split smaller than batch_size yields zero batches, so cm
    # receives no updates before compute(). Whether sample_batch walks the split
    # sequentially or samples at random is not visible here — confirm.
    reps = len(eval_set) // batch_size
    cm = ConfusionMatrix(dl.num_labels)
    iou_cur = IoU(cm)

    with torch.no_grad():
        for _ in range(reps):
            m1, m2, gt = dl.sample_batch(batch_size, mode=mode)
            # Only the third model output is scored; the first two are unused.
            _, _, res = model(m1, m2)
            res = torch.softmax(res, dim=1)
            cm.update((res, gt))

    iou_score = iou_cur.compute()

    print(f"Evaluation of {mode} set")
    print(f"mIoU: {iou_score.mean().item()}")
    print(f"IoU: {iou_score}")

    return iou_score