예제 #1
0
def perform_evaluation(classifier: str):
    """Set up the evaluation CLI, parse ``sys.argv`` and evaluate the models.

    Command-line options:

    * ``--model``: a single model file to evaluate; if omitted, every file
      matching ``runs://<classifier>/*/*.pt`` is evaluated.
    * ``--dataset``: one dataset (by ``Datasets`` value) to evaluate on; if
      omitted, ``Datasets.TRAIN`` and ``Datasets.VAL`` are used.
    * ``--out``: output folder (default ``results://<classifier>``); it is
      created if missing and receives ``evaluate.csv``.
    * ``--find-mistakes``: forwarded to ``evaluate`` as ``find_mistakes``.

    ``evaluate`` is called once per model and each returned string is written
    to ``evaluate.csv`` followed by a newline; ``include_heading`` is true only
    for the first model.

    Args:
        classifier (str): the classifier name, used to build the default model
            search location and the default output folder.
    """
    parser = argparse.ArgumentParser(description="Evaluate trained models.")
    parser.add_argument("--model", help=f"the model to evaluate (if unspecified, all models in 'runs://{classifier}' will be evaluated)",
                        type=str, default=None)
    parser.add_argument("--dataset", help="the dataset to evaluate (if unspecified, train and val will be evaluated)",
                        type=str, default=None, choices=[x.value for x in Datasets])
    parser.add_argument("--out", help="output folder", type=str,
                        default=f"results://{classifier}")
    parser.add_argument("--find-mistakes", help="whether to output all misclassification images",
                        dest="find_mistakes", action="store_true")
    parser.set_defaults(find_mistakes=False)
    args = parser.parse_args()

    output_folder = URI(args.out)
    output_folder.mkdir(parents=True, exist_ok=True)

    # Resolve the models to evaluate before opening the output file.
    if args.model is None:
        # Assuming URI.glob behaves like pathlib.Path.glob, matches arrive in
        # filesystem-dependent order; sort them so the CSV row order is
        # reproducible across runs and machines.
        models = sorted(URI(f"runs://{classifier}").glob("*/*.pt"), key=str)
        if not models:
            logger.warning(f"No models found in 'runs://{classifier}'")
    else:
        models = [URI(args.model)]

    # argparse `choices` guarantees args.dataset matches a Datasets member.
    if args.dataset is None:
        datasets = [Datasets.TRAIN, Datasets.VAL]
    else:
        datasets = [d for d in Datasets if d.value == args.dataset]

    output_csv = output_folder / "evaluate.csv"
    with output_csv.open("w") as f:
        for i, model in enumerate(models):
            logger.info(f"Processing model {i+1}/{len(models)}")
            # Request the heading (include_heading) only for the first model.
            f.write(evaluate(model, datasets, output_folder,
                             find_mistakes=args.find_mistakes,
                             include_heading=i == 0) + "\n")
예제 #2
0
    # NOTE(review): fragment of a function whose header (and the creation of
    # `parser`) lies outside the visible excerpt.
    # Register the remaining CLI options, parse them, then evaluate the
    # recognizer on each selected dataset, writing one CSV per dataset.
    parser.add_argument(
        "--dataset",
        help=
        "the dataset to evaluate (if unspecified, train and val will be evaluated)",
        type=str,
        default=None,
        choices=[x.value for x in Datasets])
    parser.add_argument("--out",
                        help="output folder",
                        type=str,
                        default=f"results://recognition")
    # Opt-in flag stored as args.save_fens (False unless --save-fens is given).
    parser.add_argument("--save-fens",
                        help="store predicted and actual FEN strings",
                        action="store_true",
                        dest="save_fens")
    parser.set_defaults(save_fens=False)
    args = parser.parse_args()
    # Create the output folder (and any missing parents) up front.
    output_folder = URI(args.out)
    output_folder.mkdir(parents=True, exist_ok=True)

    # Default to train + val; otherwise select the dataset whose value matches
    # --dataset (argparse `choices` guarantees such a member exists).
    datasets = [Datasets.TRAIN, Datasets.VAL] \
        if args.dataset is None else [d for d in Datasets if d.value == args.dataset]

    # NOTE(review): presumably a recognizer with timing instrumentation
    # (inferred from the name only) — confirm.
    recognizer = TimedChessRecognizer()

    # Images are read from data://render/<dataset>; results go to
    # <out>/<dataset>.csv.
    for dataset in datasets:
        folder = URI("data://render") / dataset.value
        logger.info(f"Evaluating dataset {folder}")
        with (output_folder / f"{dataset.value}.csv").open("w") as f:
            # evaluate() receives the open file handle; presumably it writes
            # its CSV rows into f — verify against its definition.
            evaluate(recognizer, f, folder, save_fens=args.save_fens)