def _evaluate(cfg: CN, dataset: Datasets, output_folder: Path, find_mistakes: bool = False, include_heading: bool = False) -> tuple:
    """Count the samples of a rendered dataset split whose corners are detected incorrectly.

    Every ``*.png`` image in ``data://render/<split>`` is run through
    :func:`find_corners` and compared with the ``"corners"`` entry of the
    sibling ``<stem>.json`` label. A sample counts as a mistake when corner
    detection raises, returns ``None``, or when (after both point sets are
    brought into canonical order with :func:`sort_corner_points`) any
    predicted corner lies more than ``10.`` units (Euclidean distance) from
    its ground-truth counterpart.

    Args:
        cfg (CN): the corner detection configuration passed to :func:`find_corners`
        dataset (Datasets): the split to evaluate
        output_folder (Path): not used by this function
        find_mistakes (bool, optional): not used by this function. Defaults to False.
        include_heading (bool, optional): not used by this function. Defaults to False.

    Returns:
        tuple: ``(mistakes, total)`` -- the number of incorrectly detected
        samples and the number of images evaluated
    """
    mistakes = 0
    total = 0
    folder = URI("data://render") / dataset.value
    for img_file in folder.glob("*.png"):
        total += 1
        img = cv2.imread(str(img_file))
        json_file = folder / f"{img_file.stem}.json"
        with json_file.open("r") as f:
            label = json.load(f)
        actual = np.array(label["corners"])

        try:
            predicted = find_corners(cfg, img)
        except Exception:
            # Deliberately best-effort: a detector failure on one sample is
            # counted as a mistake instead of aborting the whole evaluation.
            predicted = None

        if predicted is None:
            mistakes += 1
            continue

        # Corner order is arbitrary, so compare in a canonical ordering.
        actual = sort_corner_points(actual)
        predicted = sort_corner_points(predicted)
        if np.linalg.norm(actual - predicted, axis=-1).max() > 10.:
            mistakes += 1
    return mistakes, total
def train_classifier(name: str):
    """Command-line entry point that trains one or all configurations of a classifier.

    The available configurations are the ``*.yaml`` files in
    ``config://<name>`` whose file name does not start with ``_``; each run
    writes to ``runs://<name>/<config>``.

    Args:
        name (str): the name of the classifier (`"occupancy_classifier"` or `"piece_classifier"`)
    """
    configs_dir = URI("config://") / name

    def _train(config: str):
        # Load the chosen YAML config (resolving its base) and train it.
        cfg = CN.load_yaml_with_base(configs_dir / f"{config}.yaml")
        train(cfg, URI("runs://") / name / config)

    # Files prefixed with "_" (e.g. base configs) are not offered for training.
    configs = [path.stem
               for path in configs_dir.glob("*.yaml")
               if not path.stem.startswith("_")]

    parser = argparse.ArgumentParser(description="Train the network.")
    parser.add_argument("--config",
                        help="the configuration to train (default: all)",
                        type=str, choices=configs, default=None)
    args = parser.parse_args()

    if args.config is not None:
        logger.info(f"Training the {args.config} configuration")
        _train(args.config)
        return

    logger.info("Training all configurations one by one")
    for config in configs:
        _train(config)
def main(classifiers_folder: Path = URI("models://"), setup: callable = lambda: None):
    """Main method for running inference from the command line.

    Parses the CLI arguments, calls ``setup``, runs the chess recognizer on
    the input image and prints the predicted board, a lichess editor link,
    and a warning if the predicted position is not valid.

    Args:
        classifiers_folder (Path, optional): the path to the classifiers (supplying a different path is especially useful because the transfer learning classifiers are located at ``models://transfer_learning``). Defaults to ``models://``.
        setup (callable, optional): An optional setup function to be called after the CLI argument parser has been setup. Defaults to lambda:None.

    Raises:
        SystemExit: via ``parser.error`` if the input image cannot be read
    """
    parser = argparse.ArgumentParser(
        description="Run the chess recognition pipeline on an input image")
    parser.add_argument("file", help="path to the input image", type=str)
    parser.add_argument(
        "--white",
        help="indicate that the image is from the white player's perspective (default)",
        action="store_true", dest="color")
    parser.add_argument(
        "--black",
        help="indicate that the image is from the black player's perspective",
        action="store_false", dest="color")
    parser.set_defaults(color=True)
    args = parser.parse_args()

    setup()

    img = cv2.imread(str(URI(args.file)))
    if img is None:
        # cv2.imread does not raise for a missing/unreadable file; it returns
        # None, which would otherwise fail obscurely inside cv2.cvtColor.
        parser.error(f"could not read image file: {args.file}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    recognizer = ChessRecognizer(classifiers_folder)
    board, *_ = recognizer.predict(img, args.color)

    print(board)
    print()
    print(
        f"You can view this position at https://lichess.org/editor/{board.board_fen()}")

    if board.status() != Status.VALID:
        print()
        print("WARNING: The predicted chess position is not legal according to the rules of chess.")
        print(" You might want to try again with another picture.")
def _train_model(model_type: str) -> typing.Tuple[torch.nn.Module, CN]:
    """Continue training a saved classifier with its transfer-learning configuration.

    The first ``*.pt`` file in ``models://<model_type>`` is loaded, its
    configuration is read from
    ``config://transfer_learning/<model_type>/<model stem>.yaml``, and
    ``train_model`` is run with ``runs://transfer_learning/<model_type>`` as
    the run directory.

    Args:
        model_type (str): the classifier folder name (e.g. ``"occupancy_classifier"``)

    Returns:
        typing.Tuple[torch.nn.Module, CN]: the model object that was handed to
        ``train_model`` and the configuration used. Whether ``train_model``
        updates this object in place is not visible here -- TODO confirm.

    Raises:
        FileNotFoundError: if ``models://<model_type>`` contains no ``*.pt`` file
    """
    model_folder = URI("models://") / model_type
    # A bare next() on an empty glob would leak StopIteration; report it clearly.
    model_file = next(iter(model_folder.glob("*.pt")), None)
    if model_file is None:
        raise FileNotFoundError(
            f"No model checkpoint (*.pt) found in {model_folder}")

    yaml_file = URI("config://transfer_learning") / \
        model_type / f"{model_file.stem}.yaml"
    cfg = CN.load_yaml_with_base(yaml_file)
    run_dir = URI("runs://transfer_learning") / model_type

    model = torch.load(model_file, map_location=DEVICE)
    model = device(model)
    is_inception = "inception" in model_file.stem.lower()
    train_model(cfg, run_dir, model, is_inception,
                model_file.stem, eval_on_train=True)
    return model, cfg
def download_zip_folder(url: str, destination: os.PathLike, show_size: bool = False, skip_if_exists: bool = True):
    """Fetch a ZIP archive from ``url`` and unpack it into ``destination``.

    The archive is saved inside a temporary directory first. Any existing
    ``destination`` folder is then removed (errors ignored) before the
    archive's members are extracted into it.

    Args:
        url (str): the URL of the ZIP file
        destination (os.PathLike): the destination folder
        show_size (bool, optional): whether to display a progress bar. Defaults to False.
        skip_if_exists (bool, optional): if true, will do nothing when the destination path exists already. Defaults to True.
    """
    destination = URI(destination)
    already_present = destination.exists() if skip_if_exists else False
    if already_present:
        logger.info(
            f"Not downloading {url} to {destination} again because it already exists"
        )
        return

    with tempfile.TemporaryDirectory() as scratch_dir:
        archive_path = Path(scratch_dir) / f"{destination.name}.zip"
        download_file(url, archive_path, show_size)

        logger.info(f"Unzipping {archive_path} to {destination}")
        shutil.rmtree(destination, ignore_errors=True)
        with zipfile.ZipFile(archive_path, "r") as archive:
            archive.extractall(destination, _get_members(archive))
        logger.info(f"Finished downloading {url} to {destination}")
def download_zip_folder_from_google_drive(file_id: str, destination: os.PathLike, show_size: bool = False, skip_if_exists: bool = True):
    """Fetch a ZIP archive from Google Drive and unpack it into ``destination``.

    The archive is downloaded with ``gdd.download_file_from_google_drive``
    into a temporary directory. Any existing ``destination`` folder is then
    removed (errors ignored) before the archive's members are extracted.

    Args:
        file_id (str): the Google Drive file ID
        destination (os.PathLike): the destination folder
        show_size (bool, optional): whether to display a progress bar. Defaults to False.
        skip_if_exists (bool, optional): if true, will do nothing when the destination path exists already. Defaults to True.
    """
    destination = URI(destination)
    already_present = destination.exists() if skip_if_exists else False
    if already_present:
        logger.info(
            f"Not downloading {file_id} to {destination} again because it already exists"
        )
        return

    with tempfile.TemporaryDirectory() as scratch_dir:
        archive_path = Path(scratch_dir) / f"{destination.name}.zip"
        logger.info(f"Downloading {file_id} to {archive_path}")
        gdd.download_file_from_google_drive(file_id=file_id,
                                            dest_path=archive_path,
                                            overwrite=True,
                                            showsize=show_size)

        logger.info(f"Unzipping {archive_path} to {destination}")
        shutil.rmtree(destination, ignore_errors=True)
        with zipfile.ZipFile(archive_path, "r") as archive:
            archive.extractall(destination, _get_members(archive))
        logger.info(f"Finished downloading {file_id} to {destination}")
def create_configs(classifier: str, include_centercrop: bool = False):
    """Create the YAML configuration files for all registered models for a classifier.

    All ``*.yaml`` files in ``config://<classifier>`` whose name does not
    start with ``_`` are deleted first. Then one config is written per model
    in ``MODELS_REGISTRY[<CLASSIFIER>]`` (two per model with
    ``include_centercrop``). Each config uses
    ``_base_override_<name>.yaml`` as its base if that file exists, otherwise
    ``_base_pretrained.yaml`` / ``_base.yaml`` depending on ``model.pretrained``.

    Args:
        classifier (str): the classifier (either `"occupancy_classifier"` or `"piece_classifier"`)
        include_centercrop (bool, optional): whether to create two configs per model, one including center crop and one not. Defaults to False.
    """
    config_dir = URI("config://") / classifier
    logger.info(f"Removing YAML files from {config_dir}.")
    # Snapshot the listing before deleting so we never unlink while iterating.
    for f in list(config_dir.glob("*.yaml")):
        if not f.name.startswith("_"):
            f.unlink()

    # Explicit, ordered options (same order the original set literal yielded).
    center_crop_options = (False, True) if include_centercrop else (False,)

    for name, model in MODELS_REGISTRY[classifier.upper()].items():
        for center_crop in center_crop_options:
            config_file = config_dir / \
                (name + ("_centercrop" if center_crop else "") + ".yaml")
            # Use the module logger like the rest of this function.
            logger.info(f"Writing configuration file {config_file}")

            size = model.input_size
            C = CN()
            override_base = f"config://{classifier}/_base_override_{name}.yaml"
            if URI(override_base).exists():
                C._BASE_ = override_base
            else:
                suffix = "_pretrained" if model.pretrained else ""
                C._BASE_ = f"config://{classifier}/_base{suffix}.yaml"

            C.DATASET = CN()
            C.DATASET.TRANSFORMS = CN()
            C.DATASET.TRANSFORMS.CENTER_CROP = (50, 50) \
                if center_crop else None
            C.DATASET.TRANSFORMS.RESIZE = size

            C.TRAINING = CN()
            C.TRAINING.MODEL = CN()
            C.TRAINING.MODEL.REGISTRY = classifier.upper()
            C.TRAINING.MODEL.NAME = name

            with config_file.open("w") as f:
                C.dump(stream=f)
def perform_evaluation(classifier: str):
    """Command-line entry point for evaluating trained classifier models.

    Results for every evaluated model are appended as lines to
    ``<out>/evaluate.csv``; the first model's row is produced with
    ``include_heading=True``.

    Args:
        classifier (str): the classifier
    """
    parser = argparse.ArgumentParser(description="Evaluate trained models.")
    parser.add_argument("--model",
                        help=f"the model to evaluate (if unspecified, all models in 'runs://{classifier}' will be evaluated)",
                        type=str, default=None)
    parser.add_argument("--dataset",
                        help="the dataset to evaluate (if unspecified, train and val will be evaluated)",
                        type=str, default=None,
                        choices=[x.value for x in Datasets])
    parser.add_argument("--out", help="output folder", type=str,
                        default=f"results://{classifier}")
    parser.add_argument("--find-mistakes",
                        help="whether to output all misclassification images",
                        dest="find_mistakes", action="store_true")
    parser.set_defaults(find_mistakes=False)
    args = parser.parse_args()

    output_folder = URI(args.out)
    output_folder.mkdir(parents=True, exist_ok=True)
    output_csv = output_folder / "evaluate.csv"

    with output_csv.open("w") as csv_file:
        if args.model is None:
            models = list(URI(f"runs://{classifier}").glob("*/*.pt"))
        else:
            models = [URI(args.model)]

        if args.dataset is None:
            datasets = [Datasets.TRAIN, Datasets.VAL]
        else:
            datasets = [d for d in Datasets if d.value == args.dataset]

        model_count = len(models)
        for index, model in enumerate(models):
            logger.info(f"Processing model {index + 1}/{model_count}")
            row = evaluate(model, datasets, output_folder,
                           find_mistakes=args.find_mistakes,
                           include_heading=(index == 0))
            csv_file.write(row + "\n")
def build_dataset(cfg: CN, mode: Datasets) -> torch.utils.data.Dataset:
    """Create an ``ImageFolder`` dataset for one split.

    Images are read from ``<cfg.DATASET.PATH>/<mode.value>`` and passed
    through the transforms that :func:`build_transforms` returns for ``mode``.

    Args:
        cfg (CN): the config object
        mode (Datasets): the split (important to figure out which transforms to apply)

    Returns:
        torch.utils.data.Dataset: the dataset
    """
    split_transform = build_transforms(cfg, mode)
    split_root = URI(cfg.DATASET.PATH) / mode.value
    return torchvision.datasets.ImageFolder(root=split_root,
                                            transform=split_transform)
def __init__(self, classifiers_folder: Path = URI("models://")):
    """Load the corner detection config and both square classifiers.

    Args:
        classifiers_folder (Path, optional): the path to the classifiers (supplying a different path is especially useful because the transfer learning classifiers are located at ``models://transfer_learning``). Defaults to ``models://``.
    """
    self._corner_detection_cfg = CN.load_yaml_with_base(
        "config://corner_detection.yaml")

    # Occupancy classifier and its test-time transforms.
    occupancy_folder = classifiers_folder / "occupancy_classifier"
    self._occupancy_cfg, self._occupancy_model = self._load_classifier(
        occupancy_folder)
    self._occupancy_transforms = build_transforms(
        self._occupancy_cfg, mode=Datasets.TEST)

    # Piece classifier and its test-time transforms.
    pieces_folder = classifiers_folder / "piece_classifier"
    self._pieces_cfg, self._pieces_model = self._load_classifier(
        pieces_folder)
    self._pieces_transforms = build_transforms(
        self._pieces_cfg, mode=Datasets.TEST)

    # Lookup array mapping classifier output index -> piece object.
    self._piece_classes = np.array(
        [name_to_piece(class_name)
         for class_name in self._pieces_cfg.DATASET.CLASSES])
def download_file(url: str, destination: os.PathLike, show_size: bool = False):
    """Download a file from a URL to a destination.

    The response body is streamed to ``destination`` in 1024-byte chunks.

    Args:
        url (str): the URL
        destination (os.PathLike): the destination
        show_size (bool, optional): whether to display a progress bar. Defaults to False.

    Raises:
        requests.HTTPError: if the server answers with an error status code
    """
    destination = URI(destination)
    logger.info(f"Downloading {url} to {destination}")
    block_size = 1024

    # The context managers guarantee the connection, progress bar and file
    # are closed even if the transfer fails midway.
    with requests.get(url, stream=True) as response:
        # Without this an HTTP error page would silently be saved as the file.
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True,
                  disable=not show_size) as progress_bar, \
                destination.open("wb") as f:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                f.write(data)
"""Script to perform a single inference using the fine-tuned system on the new dataset.

.. code-block:: console

    $ python -m chesscog.transfer_learning.recognition --help
    usage: recognition.py [-h] [--white] [--black] file

    Run the chess recognition pipeline on an input image

    positional arguments:
      file        path to the input image

    optional arguments:
      -h, --help  show this help message and exit
      --white     indicate that the image is from the white player's
                  perspective (default)
      --black     indicate that the image is from the black player's
                  perspective
"""

from recap import URI
import functools

from chesscog.recognition.recognition import main


if __name__ == "__main__":
    # Only imported when run as a script.
    from chesscog.transfer_learning.download_models import ensure_models

    # main() calls this after parsing the CLI arguments.
    download_missing_models = functools.partial(ensure_models, show_size=True)
    transfer_learning_models = URI("models://transfer_learning")
    main(transfer_learning_models, setup=download_missing_models)
parser.add_argument( "--dataset", help= "the dataset to evaluate (if unspecified, train and val will be evaluated)", type=str, default=None, choices=[x.value for x in Datasets]) parser.add_argument("--out", help="output folder", type=str, default=f"results://recognition") parser.add_argument("--save-fens", help="store predicted and actual FEN strings", action="store_true", dest="save_fens") parser.set_defaults(save_fens=False) args = parser.parse_args() output_folder = URI(args.out) output_folder.mkdir(parents=True, exist_ok=True) datasets = [Datasets.TRAIN, Datasets.VAL] \ if args.dataset is None else [d for d in Datasets if d.value == args.dataset] recognizer = TimedChessRecognizer() for dataset in datasets: folder = URI("data://render") / dataset.value logger.info(f"Evaluating dataset {folder}") with (output_folder / f"{dataset.value}.csv").open("w") as f: evaluate(recognizer, f, folder, save_fens=args.save_fens)
bottom = get_nonmax_supressed(ymin - 1) if top.sum() > bottom.sum(): ymax += 1 else: ymin -= 1 return ymin, ymax if __name__ == "__main__": import matplotlib.pyplot as plt import argparse parser = argparse.ArgumentParser(description="Chessboard corner detector.") parser.add_argument("file", type=str, help="URI of the input image file") parser.add_argument("--config", type=str, help="path to the config file", default="config://corner_detection.yaml") args = parser.parse_args() cfg = CN.load_yaml_with_base(args.config) filename = URI(args.file) img = cv2.imread(str(filename)) corners = find_corners(cfg, img) fig = plt.figure() fig.canvas.set_window_title("Corner detection output") plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) plt.scatter(*corners.T, c="r") plt.axis("off") plt.show()
optional arguments: -h, --help show this help message and exit """ import chess.pgn from pathlib import Path import numpy as np import argparse from recap import URI if __name__ == "__main__": argparse.ArgumentParser( description="Create the fens.txt file by selecting 2%% of the positions from games.pgn.").parse_args() dataset_path = URI("data://games.pgn") fens_path = URI("data://fens.txt") fens = set() with dataset_path.open("r") as pgn: while (game := chess.pgn.read_game(pgn)) is not None: board = game.board() moves = list(game.mainline_moves()) moves_mask = np.random.randint(0, 50, len(moves)) == 0 for move, mask in zip(moves, moves_mask): board.push(move) if mask: color = "W" if board.turn == chess.WHITE else "B" fens.add(color + board.board_fen()) with fens_path.open("w") as f:
@listify
def _add_parameter(key: str, values: typing.Iterable[typing.Any], cfgs: typing.List[CN]) -> list:
    """Expand ``cfgs`` by every value of one (dotted) config parameter.

    Yields a cloned config for each ``(value, cfg)`` combination, with the
    node addressed by ``key`` (e.g. ``"EDGE_DETECTION.LOW_THRESHOLD"``) set
    to ``value``. ``@listify`` collects the yielded configs into a list.
    """
    for value in values:
        for base_cfg in cfgs:
            new_cfg = base_cfg.clone()
            *parent_keys, leaf_key = key.split(".")
            node = new_cfg
            for part in parent_keys:
                node = node[part]
            node[leaf_key] = value
            yield new_cfg


def _is_valid_cfg(cfg: CN) -> bool:
    """Reject configs whose low edge-detection threshold exceeds the high one."""
    edge_cfg = cfg.EDGE_DETECTION
    return edge_cfg.LOW_THRESHOLD <= edge_cfg.HIGH_THRESHOLD


if __name__ == "__main__":
    argparse.ArgumentParser(
        description="Create YAML config files for grid search.").parse_args()

    cfg_folder = URI("config://corner_detection")
    base_cfg = CN.load_yaml_with_base(cfg_folder / "_base.yaml")

    # Build the cartesian product of all grid-search parameter values.
    grid = [base_cfg]
    for param_key, param_values in parameters.items():
        grid = _add_parameter(param_key, param_values, grid)

    valid_cfgs = filter(_is_valid_cfg, grid)
    for number, generated_cfg in enumerate(valid_cfgs, start=1):
        with (cfg_folder / f"generated_{number}.yaml").open("w") as f:
            generated_cfg.dump(stream=f)
""" import numpy as np from logging import getLogger from recap import URI import argparse logger = getLogger(__name__) if __name__ == "__main__": argparse.ArgumentParser( description="Split the dataset into train/val/test.").parse_args() val_split = .03 test_split = .1 render_dir = URI("data://render") ids = np.array([x.stem for x in render_dir.glob("*.json")]) if len(ids) == 0: logger.warning( "No samples found in 'data://render', either you did not download the datset yet or you have already split it." ) np.random.seed(42) ids = np.random.permutation(ids) sample_sizes = (np.array([val_split, test_split]) * len(ids)).astype( np.int32) val, test, train = np.split(ids, sample_sizes) datasets = {"val": val, "test": test, "train": train} print( f"{len(ids)} samples will be split into {len(train)} train, {len(val)} val, {len(test)} test." )
def test_virtural_uri():
    """``data://`` URIs must resolve beneath the package's data directory."""
    from chesscog.core.io import _DATA_DIR
    resolved = URI("data://a/b")
    expected = _DATA_DIR / "a" / "b"
    assert str(resolved) == str(expected)
k[len("config."):]: v if not hasattr(v, "item") else v.item() for k, v in row.items() if k.startswith("config.") } cfg = CN.load_yaml_with_base("config://corner_detection/_base.yaml") cfg.merge_with_dict(configs) with (output_folder / f"{i:04d}.yaml").open("w") as f: cfg.dump(stream=f) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser( description= "Get the best n configs of the results obtained via grid search") parser.add_argument("--n", help="the number of configs to retain", type=int, default=100) parser.add_argument( "--in", dest="input", help="the CSV file containing the results of the grid search", type=str, default="results://corner_detection/evaluate.csv") parser.add_argument("--out", help="the output folder for the YAML files", type=str, default="config://corner_detection/refined") args = parser.parse_args() _find_best_configs(args.n, URI(args.input), URI(args.out))
"--dataset", help= "the dataset to evaluate (if unspecified, train and val will be evaluated)", type=str, default=None, choices=[x.value for x in Datasets]) parser.add_argument("--out", help="output folder", type=str, default=f"results://corner_detection") parser.set_defaults(find_mistakes=False) args = parser.parse_args() datasets = [Datasets.TRAIN, Datasets.VAL] \ if args.dataset is None else [d for d in Datasets if d.value == args.dataset] config_path = URI(args.config) if config_path.is_dir(): cfgs = URI(args.config).glob("*.yaml") cfgs = filter(lambda x: not x.name.startswith("_"), cfgs) cfgs = sorted(cfgs) cfgs = map(CN.load_yaml_with_base, cfgs) cfgs = list(cfgs) else: cfgs = [CN.load_yaml_with_base(config_path)] output_folder = URI(args.out) output_folder.mkdir(parents=True, exist_ok=True) with (output_folder / "evaluate.csv").open("w") as f: cfg_headers = None for i, cfg in enumerate(cfgs, 1): params = cfg.params_dict()
.. code-block:: console $ python -m chesscog.data_synthesis.download_pgn --help usage: download_pgn.py [-h] Download Magnus Carlsen's chess games to data://games.pgn. optional arguments: -h, --help show this help message and exit """ import urllib.request import zipfile import argparse from recap import URI if __name__ == "__main__": argparse.ArgumentParser( description="Download Magnus Carlsen's chess games to data://games.pgn." ).parse_args() zip_file = URI("data://games.zip") urllib.request.urlretrieve("https://www.pgnmentor.com/players/Carlsen.zip", zip_file) with zipfile.ZipFile(zip_file) as zip_f: with zip_f.open( "Carlsen.pgn", "r") as in_f, URI("data://games.pgn").open("wb") as out_f: out_f.write(in_f.read())
from pathlib import Path import matplotlib.pyplot as plt import cv2 from PIL import Image, ImageDraw import json import numpy as np import chess import os import shutil from recap import URI import argparse from chesscog.core import sort_corner_points from chesscog.core.dataset import piece_name RENDERS_DIR = URI("data://render") OUT_DIR = URI("data://pieces") SQUARE_SIZE = 50 BOARD_SIZE = 8 * SQUARE_SIZE IMG_SIZE = BOARD_SIZE * 2 MARGIN = (IMG_SIZE - BOARD_SIZE) / 2 MIN_HEIGHT_INCREASE, MAX_HEIGHT_INCREASE = 1, 3 MIN_WIDTH_INCREASE, MAX_WIDTH_INCREASE = .25, 1 OUT_WIDTH = int((1 + MAX_WIDTH_INCREASE) * SQUARE_SIZE) OUT_HEIGHT = int((1 + MAX_HEIGHT_INCREASE) * SQUARE_SIZE) def crop_square(img: np.ndarray, square: chess.Square, turn: chess.Color) -> np.ndarray: """Crop a chess square from the warped input image for piece classification.
import osfclient.cli import typing import zipfile import tempfile from pathlib import Path from types import SimpleNamespace from recap import URI from logging import getLogger logger = getLogger(__name__) if __name__ == "__main__": parser = argparse.ArgumentParser( description="Download the rendered dataset.").parse_args() folder = URI("data://render") with tempfile.TemporaryDirectory() as tmp: logger.info("Downloading rendered dataset from OSF") tmp = Path(tmp) args = SimpleNamespace(project="xf3ka", output=str(tmp), username=None) osfclient.cli.clone(args) shutil.rmtree(folder, ignore_errors=True) os.makedirs(folder.parent, exist_ok=True) shutil.move(tmp / "osfstorage", folder) logger.info("Merging train dataset") try: os.system( f"zip -s 0 {folder / 'train.zip'} --out {folder / 'train_full.zip'}" ) except Exception: raise Exception(f"Please manually unpack the ZIP archives at {folder}")
fill=outline_color ) draw.text( (x + margin, y + margin), text, fill=text_color, font=font) def _visualize_groundtruth(img: Image, label: dict): _draw_board_edges(img, label["corners"]) _draw_bounding_boxes(img, label["pieces"]) if __name__ == "__main__": parser = argparse.ArgumentParser( description="Visualize a sample from the dataset.") parser.add_argument("--file", type=str, help="path to image file", default="data://render/train/3828.png") args = parser.parse_args() img_file = URI(args.file) json_file = img_file.parent / f"{img_file.stem}.json" img = Image.open(img_file) with json_file.open("r") as f: label = json.load(f) _visualize_groundtruth(img, label) img.show()
from recap import URI, CfgNode as CN import cv2 import functools import json import chess from pathlib import Path import matplotlib.pyplot as plt import argparse from chesscog.corner_detection import find_corners, resize_image from chesscog.occupancy_classifier import create_dataset as create_occupancy_dataset from chesscog.piece_classifier import create_dataset as create_pieces_dataset from chesscog.core.dataset import Datasets DATASET_DIR = URI("data://transfer_learning") def _add_corners_to_train_labels(input_dir: Path): corner_detection_cfg = CN.load_yaml_with_base( "config://corner_detection.yaml") for subset in (x.value for x in (Datasets.TRAIN, Datasets.TEST)): for img_file in (input_dir / subset).glob("*.png"): img = cv2.imread(str(img_file)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img, img_scale = resize_image(corner_detection_cfg, img) corners = find_corners(corner_detection_cfg, img) corners = corners / img_scale json_file = img_file.parent / f"{img_file.stem}.json" with json_file.open("r") as f:
from pathlib import Path import matplotlib.pyplot as plt import cv2 from PIL import Image, ImageDraw import json import numpy as np import chess import os import shutil from recap import URI import argparse from chesscog.core import sort_corner_points RENDERS_DIR = URI("data://render") OUT_DIR = URI("data://occupancy") SQUARE_SIZE = 50 BOARD_SIZE = 8 * SQUARE_SIZE IMG_SIZE = BOARD_SIZE + 2 * SQUARE_SIZE def crop_square(img: np.ndarray, square: chess.Square, turn: chess.Color) -> np.ndarray: """Crop a chess square from the warped input image for occupancy classification. Args: img (np.ndarray): the warped input image square (chess.Square): the square to crop turn (chess.Color): the current player
import typing
import pandas as pd
import argparse
from recap import URI
import sys

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare distribution of mistakes per board for LaTeX")
    parser.add_argument("--results", help="parent results folder",
                        type=str, default="results://recognition")
    parser.add_argument("--dataset", help="the dataset to evaluate",
                        type=str, default="train")
    args = parser.parse_args()

    # Load data
    df = pd.read_csv(URI(args.results) / f"{args.dataset}.csv")

    # NOTE: boards whose corners could not be detected are currently NOT
    # filtered out; the filter below is disabled.
    # df = df[(df["num_incorrect_corners"] != 4) | (df["error"] != "None")]

    # Percentage of boards for each number of misclassified squares,
    # excluding perfectly classified boards (0 mistakes) from the output.
    counts = df["num_incorrect_squares"].value_counts()
    counts = counts / counts.sum() * 100
    counts = counts[counts.index != 0]

    for i, count in counts.items():
        print(f"({i:2d},{count:5.02f})")
    # ".02f" (fixed-point); a bare ".02" is general format and can switch to
    # scientific notation, e.g. 12.34 -> "1.2e+01".
    print(
        f"Proportion of boards classified with >=2 mistakes: {counts[counts.index >= 2].sum():.02f}%")
import pandas as pd import re import argparse from recap import URI if __name__ == "__main__": parser = argparse.ArgumentParser(description="Prepare results for LaTeX") parser.add_argument("--classifier", type=str, choices=["occupancy_classifier", "piece_classifier"], default="occupancy_classifier") args = parser.parse_args() classifier = args.classifier df = pd.read_csv(URI(f"results://{classifier}/evaluate.csv")) pattern = re.compile(r"^confusion_matrix/(\d+)/(\d+)$") off_diagonal_confusion_matrix_mask = df.columns \ .map(pattern.match) \ .map(lambda match: match and match.group(1) != match.group(2)) \ .fillna(False) df["misclassified"] = df.loc[:, off_diagonal_confusion_matrix_mask].sum( axis="columns") df["accuracy"] *= 100 df_train = df[df["dataset"] == "train"] \ .set_index("model") \ .drop(columns="dataset") \ .rename(columns=lambda x: f"train_{x}") df_val = df[df["dataset"] == "val"] \
def test_ensure_model(ensure_model: typing.Callable, name: str):
    """After ``ensure_model`` runs, ``models://<name>`` holds at least one ``*.pt`` file."""
    ensure_model(show_size=False)
    checkpoints = URI(f"models://{name}").glob("*.pt")
    assert next(iter(checkpoints), None) is not None
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate the chess recognition system end-to-end.")
    parser.add_argument(
        "--dataset",
        help="the dataset to evaluate (if unspecified, train and test will be evaluated)",
        type=str, default=None, choices=[x.value for x in Datasets])
    parser.add_argument("--out", help="output folder", type=str,
                        default="results://transfer_learning/recognition")
    # (A stray set_defaults(find_mistakes=False) was removed: this script has
    # no --find-mistakes option and never reads that attribute.)
    args = parser.parse_args()

    output_folder = URI(args.out)
    output_folder.mkdir(parents=True, exist_ok=True)

    datasets = [Datasets.TRAIN, Datasets.TEST] \
        if args.dataset is None else [d for d in Datasets if d.value == args.dataset]
    recognizer = TimedChessRecognizer(URI("models://transfer_learning"))

    # One CSV per evaluated split, always including the predicted/actual FENs.
    for dataset in datasets:
        folder = URI("data://transfer_learning/images") / dataset.value
        logger.info(f"Evaluating dataset {folder}")
        with (output_folder / f"{dataset.value}.csv").open("w") as f:
            evaluate(recognizer, f, folder, save_fens=True)