def __init__(self,
                 root,
                 transform=None,
                 pre_transform=None,
                 fold="train",
                 small=False,
                 pool_size=1):
        assert fold in [
            "train", "val", "test_images"
        ], "Input fold={} should be in [\"train\", \"val\", \"test_images\"]".format(
            fold)
        if fold == "test_images":
            print_utils.print_error(
                "ERROR: fold {} not yet implemented!".format(fold))
            exit()
        self.root = root
        self.fold = fold
        makedirs(self.processed_dir)
        self.small = small
        if self.small:
            print_utils.print_info(
                "INFO: Using small version of the Mapping challenge dataset.")
        self.pool_size = pool_size

        self.coco = None
        self.image_id_list = self.load_image_ids()
        self.stats_filepath = os.path.join(self.processed_dir, "stats.pt")
        self.stats = None
        if os.path.exists(self.stats_filepath):
            self.stats = torch.load(self.stats_filepath)
        self.processed_flag_filepath = os.path.join(
            self.processed_dir,
            "processed-flag-small" if self.small else "processed-flag")

        super(MappingChallenge, self).__init__(root, transform, pre_transform)
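
The processed-flag file computed above marks a finished pre-processing run so later instantiations can skip the expensive work. A minimal standalone sketch of the pattern (the flag path is illustrative, not from the source):

import pathlib

flag = pathlib.Path("/tmp/processed-flag")  # illustrative location
if not flag.exists():
    print("pre-processing...")  # the expensive one-time work would go here
    flag.touch()                # mark completion so subsequent runs skip it
else:
    print("already processed, skipping.")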
Example 2
    def load_checkpoint(self):
        """
        Loads best val checkpoint in checkpoints_dirpath
        """
        filepaths = python_utils.get_filepaths(
            self.checkpoints_dirpath,
            startswith_str="checkpoint.best_val.",
            endswith_str=".tar")
        if len(filepaths):
            filepaths = sorted(filepaths)
            filepath = filepaths[
                -1]  # Last best val checkpoint filepath in case there is more than one
            if self.gpu == 0:
                print_utils.print_info(
                    "Loading best val checkpoint: {}".format(filepath))
        else:
            # No best val checkpoint found: find last checkpoint:
            filepaths = python_utils.get_filepaths(
                self.checkpoints_dirpath,
                endswith_str=".tar",
                startswith_str="checkpoint.")
            if len(filepaths) == 0:
                raise FileNotFoundError(
                    "No checkpoint could be found at that location.")
            filepaths = sorted(filepaths)
            filepath = filepaths[-1]  # Last checkpoint
            if self.gpu == 0:
                print_utils.print_info(
                    "Loading last checkpoint: {}".format(filepath))
        # map_location is used to load on current device:
        checkpoint = torch.load(filepath,
                                map_location="cuda:{}".format(self.gpu))

        self.model.module.load_state_dict(checkpoint['model_state_dict'])
Example 3
def load_checkpoint(model, checkpoints_dirpath, device):
    """
    Loads best val checkpoint in checkpoints_dirpath
    """
    filepaths = python_utils.get_filepaths(
        checkpoints_dirpath,
        startswith_str="checkpoint.best_val.",
        endswith_str=".tar")
    if len(filepaths):
        filepaths = sorted(filepaths)
        filepath = filepaths[
            -1]  # Last best val checkpoint filepath in case there is more than one
        print_utils.print_info(
            "Loading best val checkpoint: {}".format(filepath))
    else:
        # No best val checkpoint found: find last checkpoint:
        filepaths = python_utils.get_filepaths(checkpoints_dirpath,
                                               endswith_str=".tar",
                                               startswith_str="checkpoint.")
        if len(filepaths) == 0:
            raise FileNotFoundError(
                "No checkpoint could be found at that location.")
        filepaths = sorted(filepaths)
        filepath = filepaths[-1]  # Last checkpoint
        print_utils.print_info("Loading last checkpoint: {}".format(filepath))

    device = torch.device(device)
    checkpoint = torch.load(
        filepath,
        map_location=device)  # map_location is used to load on current device

    model.load_state_dict(checkpoint['model_state_dict'])

    return model
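
A minimal usage sketch of this function, assuming a run directory laid out as above (the paths and model construction are illustrative, not from the source):

model = FrameFieldModel(config, backbone=backbone)  # hypothetical construction
model = load_checkpoint(model, "runs/my_run/checkpoints", "cuda:0")
model.eval()  # checkpointed weights are now loaded on the requested device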
Example 4
def inference_no_patching(config, model, tile_data):
    with torch.no_grad():
        batch = {
            "image": tile_data["image"],
            "image_mean": tile_data["image_mean"],
            "image_std": tile_data["image_std"]
        }
        try:
            pred, batch = network_inference(config, model, batch)
        except RuntimeError as e:
            print_utils.print_error("ERROR: " + str(e))
            if 1 < config["optim_params"]["eval_batch_size"]:
                print_utils.print_info(
                    "INFO: Try lowering the effective batch_size (which is {} currently). "
                    "Note that in eval mode, the effective batch_size is equal to double the batch_size "
                    "because gradients do not need to "
                    "be computed so double the memory is available. "
                    "You can override the effective batch_size with the eval_batch_size parameter."
                    .format(config["optim_params"]["eval_batch_size"]))
            else:
                print_utils.print_info(
                    "INFO: The effective batch_size is 1 but the GPU still ran out of memory. "
                    "You can specify parameters to split the image into patches for inference:\n"
                    "--eval_patch_size is the size of the patch and should be chosen as big as memory allows.\n"
                    "--eval_patch_overlap (optional, default=200) adds overlaps between patches to avoid border artifacts.")
            raise e

        tile_data["seg"] = pred["seg"]
        if "crossfield" in pred:
            tile_data["crossfield"] = pred["crossfield"]

    return tile_data
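
The try/except above turns an opaque CUDA out-of-memory RuntimeError into an actionable hint before re-raising. A generic sketch of the same pattern (names are illustrative):

import torch

def run_with_oom_hint(model, batch):
    try:
        with torch.no_grad():
            return model(batch)
    except RuntimeError as e:
        if "out of memory" in str(e):
            print("Hint: lower the batch size or enable patched inference.")
        raise  # re-raise so callers still see the original error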
def main():
    args = get_args()
    print_utils.print_info(
        f"INFO: evaluating {len(args.pred_filepath)} predictions.")

    # Match files together
    im_gt_pred_filepaths = match_im_gt_pred(args.im_filepath, args.gt_filepath,
                                            args.pred_filepath)

    pool = Pool()
    metrics_iou_list = list(
        tqdm(pool.imap(partial(eval_one, overwrite=args.overwrite),
                       im_gt_pred_filepaths),
             desc="Compute eval metrics",
             total=len(im_gt_pred_filepaths)))

    # Aggregate metrics and IoU
    aggr_metrics = {"max_angle_diffs": []}
    aggr_iou = {"intersection": 0, "union": 0}
    for metrics_iou in metrics_iou_list:
        if metrics_iou["metrics"]:
            aggr_metrics["max_angle_diffs"] += metrics_iou["metrics"][
                "max_angle_diffs"]
        if metrics_iou["iou"]:
            aggr_iou["intersection"] += metrics_iou["iou"]["intersection"]
            aggr_iou["union"] += metrics_iou["iou"]["union"]
    aggr_iou["iou"] = aggr_iou["intersection"] / aggr_iou["union"]

    aggr_metrics_filepath = os.path.join(
        os.path.dirname(args.pred_filepath[0]), "aggr_metrics.json")
    aggr_iou_filepath = os.path.join(os.path.dirname(args.pred_filepath[0]),
                                     "aggr_iou.json")
    python_utils.save_json(aggr_metrics_filepath, aggr_metrics)
    python_utils.save_json(aggr_iou_filepath, aggr_iou)
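
Note that summing intersections and unions before dividing yields the dataset-level IoU, which generally differs from averaging per-image IoUs. A small numeric sketch of the difference:

import numpy as np

inter = np.array([10, 0])
union = np.array([20, 5])
print((inter / union).mean())     # 0.25: mean of per-image IoUs
print(inter.sum() / union.sum())  # 0.4: aggregated IoU, as computed above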
def eval_process(gpu, config, shared_dict, barrier):
    from frame_field_learning.evaluate import evaluate

    torch.manual_seed(0)  # Ensure same seed for all processes
    # --- Find data directory --- #
    root_dir_candidates = [
        os.path.join(data_dirpath, config["dataset_params"]["root_dirname"])
        for data_dirpath in config["data_dir_candidates"]
    ]
    root_dir, paths_tried = python_utils.choose_first_existing_path(
        root_dir_candidates, return_tried_paths=True)
    if root_dir is None:
        print_utils.print_error(
            "GPU {} -> ERROR: Data root directory amongst \"{}\" not found!".
            format(gpu, paths_tried))
        raise NotADirectoryError(
            f"Couldn't find a directory in {paths_tried} (gpu:{gpu})")
    print_utils.print_info("GPU {} -> Using data from {}".format(
        gpu, root_dir))
    config["data_root_dir"] = root_dir

    # --- Get dataset
    # - CHANGE HERE TO ADD YOUR OWN DATASET
    eval_ds, = get_folds(
        config, root_dir,
        folds=config["fold"])  # config["fold"] is already a list (of length 1)

    # --- Instantiate backbone network (its backbone will be used to extract features)
    backbone = get_backbone(config["backbone_params"])

    evaluate(gpu, config, shared_dict, barrier, eval_ds, backbone)
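
python_utils.choose_first_existing_path is a project helper; a minimal equivalent sketch, ignoring the return_tried_paths option used above:

import os

def choose_first_existing_path(paths):
    for path in paths:
        path = os.path.expanduser(path)
        if os.path.exists(path):
            return path
    return None  # callers treat None as "no data directory found"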
def launch_train(args):
    assert args.config is not None, "Argument --config must be specified. Run 'python main.py --help' for help on arguments."
    config = run_utils.load_config(args.config)
    if config is None:
        print_utils.print_error(
            "ERROR: cannot continue without a config file. Exiting now...")
        sys.exit()
    config["runs_dirpath"] = args.runs_dirpath
    if args.run_name is not None:
        config["run_name"] = args.run_name
    config["new_run"] = args.new_run
    config["init_run_name"] = args.init_run_name
    if args.samples is not None:
        config["samples"] = args.samples
    if args.batch_size is not None:
        config["optim_params"]["batch_size"] = args.batch_size
    if args.max_epoch is not None:
        config["optim_params"]["max_epoch"] = args.max_epoch

    if args.fold is None:
        if "fold" in config:
            fold = set(config["fold"])
        else:
            fold = {"train"}  # Default values for train
    else:
        fold = set(args.fold)
    assert fold == {"train"} or fold == {"train", "val"}, \
        "Argument fold when training should be either: ['train'] or ['train', 'val']"
    config["fold"] = list(fold)
    print_utils.print_info("Training on fold(s): {}".format(config["fold"]))

    config["nodes"] = args.nodes
    config["gpus"] = args.gpus
    config["nr"] = args.nr
    config["world_size"] = args.gpus * args.nodes

    # --- Load params in config set as relative path to another JSON file
    config = run_utils.load_defaults_in_config(
        config, filepath_key="defaults_filepath")

    # Setup num_workers per process:
    if config["num_workers"] is None:
        config["num_workers"] = int(torch.multiprocessing.cpu_count() /
                                    config["gpus"])

    # --- Distributed init:
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port
    manager = torch.multiprocessing.Manager()
    shared_dict = manager.dict()
    shared_dict["run_dirpath"] = None
    shared_dict["init_checkpoints_dirpath"] = None
    barrier = manager.Barrier(args.gpus)

    torch.multiprocessing.spawn(train_process,
                                nprocs=args.gpus,
                                args=(config, shared_dict, barrier))
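
torch.multiprocessing.spawn passes the process index as the first argument of the target function, which is why train_process receives gpu first while args only supplies the remaining arguments. A self-contained sketch:

import torch.multiprocessing as mp

def worker(gpu, message):
    print(f"process {gpu} says: {message}")

if __name__ == "__main__":
    mp.spawn(worker, nprocs=2, args=("hello",))  # spawns workers 0 and 1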
def main():
    args = get_args()
    print_utils.print_info(
        f"INFO: converting {len(args.filepath)} seg images.")

    pool = Pool()
    list(
        tqdm(pool.imap(partial(convert_one, out_dirpath=args.out_dirpath),
                       args.filepath),
             desc="RGB to Gray",
             total=len(args.filepath)))
Example 9
def inference_from_filepath(config, in_filepaths, backbone):
    # --- Online transform performed on the device (GPU):
    eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(config)

    print("Loading model...")
    model = FrameFieldModel(config, backbone=backbone, eval_transform=eval_online_cuda_transform)
    model.to(config["device"])
    checkpoints_dirpath = run_utils.setup_run_subdir(config["eval_params"]["run_dirpath"], config["optim_params"]["checkpoints_dirname"])
    model = inference.load_checkpoint(model, checkpoints_dirpath, config["device"])
    model.eval()

    # Read image
    pbar = tqdm(in_filepaths, desc="Infer images")
    for in_filepath in pbar:
        pbar.set_postfix(status="Loading image")
        image = skimage.io.imread(in_filepath)
        if 3 < image.shape[2]:
            print_utils.print_info(f"Image {in_filepath} has more than 3 channels. Keeping the first 3 channels and discarding the rest...")
            image = image[:, :, :3]
        elif image.shape[2] < 3:
            print_utils.print_error(f"Image {in_filepath} has only {image.shape[2]} channels but the network expects 3 channels.")
            raise ValueError(f"Image {in_filepath} has {image.shape[2]} channels but the network expects 3.")
        image_float = image / 255
        mean = np.mean(image_float.reshape(-1, image_float.shape[-1]), axis=0)
        std = np.std(image_float.reshape(-1, image_float.shape[-1]), axis=0)
        sample = {
            "image": torchvision.transforms.functional.to_tensor(image)[None, ...],
            "image_mean": torch.from_numpy(mean)[None, ...],
            "image_std": torch.from_numpy(std)[None, ...],
            "image_filepath": [in_filepath],
        }

        pbar.set_postfix(status="Inference")
        tile_data = inference.inference(config, model, sample, compute_polygonization=True)

        tile_data = local_utils.batch_to_cpu(tile_data)

        # Remove batch dim:
        tile_data = local_utils.split_batch(tile_data)[0]

        pbar.set_postfix(status="Saving output")
        base_filepath = os.path.splitext(in_filepath)[0]
        if config["compute_seg"]:
            seg_mask = 0.5 < tile_data["seg"][0]
            save_utils.save_seg_mask(seg_mask, base_filepath + ".mask", tile_data["image_filepath"])
            save_utils.save_seg(tile_data["seg"], base_filepath, "seg", tile_data["image_filepath"])
            save_utils.save_seg_luxcarta_format(tile_data["seg"], base_filepath, "seg_luxcarta_format", tile_data["image_filepath"])
        if config["compute_crossfield"]:
            save_utils.save_crossfield(tile_data["crossfield"], base_filepath, "crossfield")
        if "poly_viz" in config["eval_params"]["save_individual_outputs"] and \
                config["eval_params"]["save_individual_outputs"]["poly_viz"]:
            save_utils.save_poly_viz(tile_data["image"], tile_data["polygons"], tile_data["polygon_probs"], base_filepath, "poly_viz")
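
The loop above computes per-image, per-channel mean/std for normalization, a common alternative to dataset-level statistics. A standalone sketch of that computation (the image is a placeholder):

import numpy as np

image_float = np.random.rand(256, 256, 3)  # placeholder image in [0, 1]
flat = image_float.reshape(-1, image_float.shape[-1])
mean = flat.mean(axis=0)  # shape (3,): one mean per channel
std = flat.std(axis=0)    # shape (3,): one std per channel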
def train_process(gpu, config, shared_dict, barrier):
    from frame_field_learning.train import train

    print_utils.print_info(
        "GPU {} -> Ready. There are {} GPU(s) available on this node.".format(
            gpu, torch.cuda.device_count()))

    torch.manual_seed(0)  # Ensure same seed for all processes
    # --- Find data directory --- #
    root_dir_candidates = [
        os.path.join(data_dirpath, config["dataset_params"]["root_dirname"])
        for data_dirpath in config["data_dir_candidates"]
    ]
    root_dir, paths_tried = python_utils.choose_first_existing_path(
        root_dir_candidates, return_tried_paths=True)
    if root_dir is None:
        print_utils.print_error(
            "GPU {} -> ERROR: Data root directory amongst \"{}\" not found!".
            format(gpu, paths_tried))
        exit()
    print_utils.print_info("GPU {} -> Using data from {}".format(
        gpu, root_dir))

    # --- Get dataset splits
    # - CHANGE HERE TO ADD YOUR OWN DATASET
    # We have to adapt the config["fold"] param to the folds argument of the get_folds function
    fold = set(config["fold"])
    if fold == {"train"}:
        # Val will be used for evaluating the model after each epoch:
        train_ds, val_ds = get_folds(config, root_dir, folds=["train", "val"])
    elif fold == {"train", "val"}:
        # Both train and val are meant to be used for training
        train_ds, = get_folds(config, root_dir, folds=["train_val"])
        val_ds = None
    else:
        # Should not arrive here since main makes sure config["fold"] is either one of the above
        print_utils.print_error("ERROR: specified folds not recognized!")
        raise NotImplementedError

    # --- Instantiate backbone network
    if config["backbone_params"]["name"] in ["deeplab50", "deeplab101"]:
        assert 1 < config["optim_params"]["batch_size"], \
            "When using backbone {}, batch_size has to be at least 2 for the batchnorm of the ASPPPooling to work."\
                .format(config["backbone_params"]["name"])
    backbone = get_backbone(config["backbone_params"])

    # --- Launch training
    train(gpu, config, shared_dict, barrier, train_ds, val_ds, backbone)
Example 11
def eval_coco(config):
    assert len(config["fold"]) == 1, "There should be only one specified fold"
    fold = config["fold"][0]
    if fold != "test":
        raise NotImplementedError

    pool = Pool(processes=config["num_workers"])

    # Find data dir
    root_dir_candidates = [os.path.join(data_dirpath, config["dataset_params"]["root_dirname"]) for data_dirpath in
                           config["data_dir_candidates"]]
    root_dir, paths_tried = python_utils.choose_first_existing_path(root_dir_candidates, return_tried_paths=True)
    if root_dir is None:
        print_utils.print_error(
            "ERROR: Data root directory amongst \"{}\" not found!".format(paths_tried))
        exit()
    print_utils.print_info("Using data from {}".format(root_dir))
    raw_dir = os.path.join(root_dir, "raw")

    # Get run's eval results dir
    results_dirpath = os.path.join(root_dir, config["eval_params"]["results_dirname"])
    run_results_dirpath = run_utils.setup_run_dir(results_dirpath, config["eval_params"]["run_name"], check_exists=True)

    # Setup coco
    annType = 'segm'

    # initialize COCO ground truth api
    gt_annotation_filename = "annotation-small.json" if config["dataset_params"]["small"] else "annotation.json"
    gt_annotation_filepath = os.path.join(raw_dir, "val",
                                          gt_annotation_filename)  # We are using the original val fold as our test fold
    print_utils.print_info("INFO: Load gt from " + gt_annotation_filepath)
    cocoGt = COCO(gt_annotation_filepath)

    # image_id = 0
    # annotation_ids = cocoGt.getAnnIds(imgIds=image_id)
    # annotation_list = cocoGt.loadAnns(annotation_ids)
    # print(annotation_list)

    # initialize COCO detections api
    annotation_filename_list = fnmatch.filter(os.listdir(run_results_dirpath), fold + ".annotation.*.json")
    eval_one_partial = partial(eval_one, run_results_dirpath=run_results_dirpath, cocoGt=cocoGt, config=config, annType=annType, pool=pool)

    # with Pool(8) as p:
    #     r = list(tqdm(p.imap(eval_one_partial, annotation_filename_list), total=len(annotation_filename_list)))
    for annotation_filename in annotation_filename_list:
        eval_one_partial(annotation_filename)
Example 12
    def __init__(self, gpu: int, config: dict, shared_dict, barrier, model,
                 run_dirpath):
        self.gpu = gpu
        self.config = config
        assert 0 < self.config["eval_params"]["batch_size_mult"], \
            "batch_size_mult in eval_params should be at least 1."

        self.shared_dict = shared_dict
        self.barrier = barrier
        self.model = model

        self.checkpoints_dirpath = run_utils.setup_run_subdir(
            run_dirpath, config["optim_params"]["checkpoints_dirname"])

        self.eval_dirpath = os.path.join(config["data_root_dir"], "eval_runs",
                                         os.path.split(run_dirpath)[-1])
        if self.gpu == 0:
            os.makedirs(self.eval_dirpath, exist_ok=True)
            print_utils.print_info("Saving eval outputs to {}".format(
                self.eval_dirpath))
    def __init__(self,
                 loss_funcs,
                 weights,
                 epoch_thresholds=None,
                 pre_processes=None):
        """

        @param loss_funcs:
        @param weights:
        @param pre_processes: List of functions to call with 2 arguments (which are updated): pred_batch, gt_batch to compute only one values used by several losses.
        """
        super(MultiLoss, self).__init__()
        assert len(loss_funcs) == len(weights), \
            "Should have the same amount of loss_funcs ({}) and weights ({})".format(len(loss_funcs), len(weights))
        self.loss_funcs = torch.nn.ModuleList(loss_funcs)

        self.weights = []
        for weight in weights:
            if isinstance(weight, list):
                # Weight is a list of coefs corresponding to epoch_thresholds, they will be interpolated in-between
                self.weights.append(
                    scipy.interpolate.interp1d(epoch_thresholds,
                                               weight,
                                               bounds_error=False,
                                               fill_value=(weight[0],
                                                           weight[-1])))
            elif isinstance(weight, float) or isinstance(weight, int):
                self.weights.append(float(weight))
            else:
                raise TypeError(
                    f"Type {type(weight)} not supported as a loss coef weight."
                )

        self.pre_processes = pre_processes

        for loss_func, weight in zip(self.loss_funcs, self.weights):
            if weight == 0:
                print_utils.print_info(
                    f"INFO: loss '{loss_func.name}' has a weight of zero and thus won't affect grad update."
                )
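
The list-valued weights above implement a per-epoch loss schedule via linear interpolation, clamped outside the thresholds. A standalone sketch of that mechanism (the values are illustrative):

import scipy.interpolate

epoch_thresholds = [0, 5, 10]
weight = [0.0, 1.0, 1.0]
w = scipy.interpolate.interp1d(epoch_thresholds, weight, bounds_error=False,
                               fill_value=(weight[0], weight[-1]))
print(float(w(2.5)))  # 0.5: halfway between the epoch-0 and epoch-5 coefficients
print(float(w(100)))  # 1.0: clamped to the last coefficient beyond epoch 10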
Example 14
def inference_with_patching(config, model, tile_data):
    assert len(tile_data["image"].shape) == 4 and tile_data["image"].shape[0] == 1, \
        f"When using inference with patching, tile_data should have a batch size of 1, " \
        f"with image's shape being (1, C, H, W), not {tile_data['image'].shape}"
    with torch.no_grad():
        # Init tile outputs (image is (N, C, H, W)):
        height = tile_data["image"].shape[2]
        width = tile_data["image"].shape[3]
        seg_channels = config["seg_params"]["compute_interior"] \
                       + config["seg_params"]["compute_edge"] \
                       + config["seg_params"]["compute_vertex"]
        if config["compute_seg"]:
            tile_data["seg"] = torch.zeros((1, seg_channels, height, width),
                                           device=config["device"])
        if config["compute_crossfield"]:
            tile_data["crossfield"] = torch.zeros((1, 4, height, width),
                                                  device=config["device"])
        weight_map = torch.zeros(
            (1, 1, height, width), device=config["device"]
        )  # Count number of patches on top of each pixel

        # Split tile in patches:
        stride = config["eval_params"]["patch_size"] - config["eval_params"][
            "patch_overlap"]
        patch_boundingboxes = image_utils.compute_patch_boundingboxes(
            (height, width),
            stride=stride,
            patch_res=config["eval_params"]["patch_size"])
        # Compute patch pixel weights to merge overlapping patches back together smoothly:
        patch_weights = np.ones((config["eval_params"]["patch_size"] + 2,
                                 config["eval_params"]["patch_size"] + 2),
                                dtype=np.float64)  # np.float is removed in NumPy >= 1.24
        patch_weights[0, :] = 0
        patch_weights[-1, :] = 0
        patch_weights[:, 0] = 0
        patch_weights[:, -1] = 0
        patch_weights = scipy.ndimage.distance_transform_edt(patch_weights)
        patch_weights = patch_weights[1:-1, 1:-1]
        patch_weights = torch.tensor(patch_weights,
                                     device=config["device"]).float()
        patch_weights = patch_weights[
            None, None, :, :]  # Adding batch and channels dims

        # Predict on each patch and save in outputs:
        for bbox in tqdm(patch_boundingboxes,
                         desc="Running model on patches",
                         leave=False):
            # Crop data
            batch = {
                "image": tile_data["image"][:, :, bbox[0]:bbox[2],
                                            bbox[1]:bbox[3]],
                "image_mean": tile_data["image_mean"],
                "image_std": tile_data["image_std"],
            }
            # Send batch to device
            try:
                pred, batch = network_inference(config, model, batch)
            except RuntimeError as e:
                print_utils.print_error("ERROR: " + str(e))
                print_utils.print_info(
                    "INFO: Reduce --eval_patch_size until the patch fits in memory."
                )
                raise e

            if config["compute_seg"]:
                tile_data[
                    "seg"][:, :, bbox[0]:bbox[2],
                           bbox[1]:bbox[3]] += patch_weights * pred["seg"]
            if config["compute_crossfield"]:
                tile_data["crossfield"][:, :, bbox[0]:bbox[2], bbox[1]:bbox[
                    3]] += patch_weights * pred["crossfield"]
            weight_map[:, :, bbox[0]:bbox[2], bbox[1]:bbox[3]] += patch_weights

        # Take care of overlapping parts
        if config["compute_seg"]:
            tile_data["seg"] /= weight_map
        if config["compute_crossfield"]:
            tile_data["crossfield"] /= weight_map

    return tile_data
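
The distance-transform trick above produces weights that are smallest near patch borders and largest at the center, so overlapping predictions blend smoothly after the division by weight_map. A small sketch of the weight computation alone (patch_size is illustrative):

import numpy as np
import scipy.ndimage

patch_size = 8  # the real value comes from config["eval_params"]["patch_size"]
w = np.ones((patch_size + 2, patch_size + 2), dtype=np.float64)
w[0, :] = w[-1, :] = w[:, 0] = w[:, -1] = 0  # zero ring around the patch
w = scipy.ndimage.distance_transform_edt(w)[1:-1, 1:-1]
print(w)  # weights grow from 1 at the borders to ~patch_size/2 at the center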
Example 15
def eval_one(annotation_filename, run_results_dirpath, cocoGt, config, annType, pool=None):
    print("---eval_one")
    annotation_name = os.path.splitext(annotation_filename)[0]
    if "samples" in config:
        stats_filepath = os.path.join(run_results_dirpath,
                                      "{}.stats.{}.{}.json".format("test", annotation_name, config["samples"]))
        metrics_filepath = os.path.join(run_results_dirpath,
                                      "{}.metrics.{}.{}.json".format("test", annotation_name, config["samples"]))
    else:
        stats_filepath = os.path.join(run_results_dirpath, "{}.stats.{}.json".format("test", annotation_name))
        metrics_filepath = os.path.join(run_results_dirpath, "{}.metrics.{}.json".format("test", annotation_name))

    res_filepath = os.path.join(run_results_dirpath, annotation_filename)
    if not os.path.exists(res_filepath):
        print_utils.print_warning("WARNING: result not found at filepath {}".format(res_filepath))
        return
    print_utils.print_info("Evaluate {} annotations:".format(annotation_filename))
    try:
        cocoDt = cocoGt.loadRes(res_filepath)
    except AssertionError as e:
        print_utils.print_error("ERROR: {}".format(e))
        print_utils.print_info("INFO: continuing by removing unrecognised images")
        res = json.load(open(res_filepath))
        print("Initial res length:", len(res))
        annsImgIds = [ann["image_id"] for ann in res]
        image_id_rm = set(annsImgIds) - set(cocoGt.getImgIds())
        print_utils.print_warning("Remove {} image ids!".format(len(image_id_rm)))
        new_res = [ann for ann in res if ann["image_id"] not in image_id_rm]
        print("New res length:", len(new_res))
        cocoDt = cocoGt.loadRes(new_res)
        # {4601886185638229705, 4602408603195004682, 4597274499619802317, 4600985465712755606, 4597238470822783353,
        #  4597418614807878173}


    # image_id = 0
    # annotation_ids = cocoDt.getAnnIds(imgIds=image_id)
    # annotation_list = cocoDt.loadAnns(annotation_ids)
    # print(annotation_list)

    if not os.path.exists(stats_filepath):
        # Run COCOeval
        cocoEval = COCOeval(cocoGt, cocoDt, annType)
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()

        # Save stats
        stats = {}
        stat_names = ["AP", "AP_50", "AP_75", "AP_S", "AP_M", "AP_L", "AR", "AR_50", "AR_75", "AR_S", "AR_M", "AR_L"]
        assert len(stat_names) == cocoEval.stats.shape[0]
        for i, stat_name in enumerate(stat_names):
            stats[stat_name] = cocoEval.stats[i]

        python_utils.save_json(stats_filepath, stats)
    else:
        print("COCO stats already computed, skipping...")

    if not os.path.exists(metrics_filepath):
        # Verify that cocoDt has polygonal segmentation masks and not raster masks:
        if isinstance(cocoDt.loadAnns(cocoDt.getAnnIds(imgIds=cocoDt.getImgIds()[0]))[0]["segmentation"], list):
            metrics = {}
            # Run additional metrics
            print_utils.print_info("INFO: Running contour metrics")
            contour_eval = ContourEval(cocoGt, cocoDt)
            max_angle_diffs = contour_eval.evaluate(pool=pool)
            metrics["max_angle_diffs"] = list(max_angle_diffs)
            python_utils.save_json(metrics_filepath, metrics)
    else:
        print("Contour metrics already computed, skipping...")
Example 16
    def evaluate(self, split_name: str, ds: torch.utils.data.DataLoader):

        # Prepare data saving:
        flag_filepath_format = os.path.join(self.eval_dirpath, split_name,
                                            "{}.flag")

        # Loading model
        self.load_checkpoint()
        self.model.eval()

        # Create pool for multiprocessing
        pool = None
        if not self.config["eval_params"]["patch_size"]:
            # If single image is not being split up, then a pool to process each sample in the batch makes sense
            pool = Pool(processes=self.config["num_workers"])

        compute_polygonization = self.config["eval_params"]["save_individual_outputs"]["poly_shapefile"] or \
                                 self.config["eval_params"]["save_individual_outputs"]["poly_geojson"] or \
                                 self.config["eval_params"]["save_individual_outputs"]["poly_viz"] or \
                                 self.config["eval_params"]["save_aggregated_outputs"]["poly_coco"]

        # Saving individual outputs to disk:
        save_individual_outputs = True in self.config["eval_params"][
            "save_individual_outputs"].values()
        saver_async = None
        if save_individual_outputs:
            save_outputs_partial = partial(
                save_utils.save_outputs,
                config=self.config,
                eval_dirpath=self.eval_dirpath,
                split_name=split_name,
                flag_filepath_format=flag_filepath_format)
            saver_async = async_utils.Async(save_outputs_partial)
            saver_async.start()

        # Saving aggregated outputs
        save_aggregated_outputs = True in self.config["eval_params"][
            "save_aggregated_outputs"].values()

        tile_data_list = []

        if self.gpu == 0:
            tile_iterator = tqdm(ds,
                                 desc="Eval {}: ".format(split_name),
                                 leave=True)
        else:
            tile_iterator = ds
        for tile_i, tile_data in enumerate(tile_iterator):
            # --- Inference, add result to tile_data_list
            if self.config["eval_params"]["patch_size"] is not None:
                # Cut image into patches for inference
                inference.inference_with_patching(self.config, self.model,
                                                  tile_data)
            else:
                # Feed images as-is to the model
                inference.inference_no_patching(self.config, self.model,
                                                tile_data)

            tile_data_list.append(tile_data)

            # --- Accumulate batches into tile_data_list until capacity is reached (or this is the last batch)
            if self.config["eval_params"]["batch_size_mult"] <= len(tile_data_list)\
                    or tile_i == len(tile_iterator) - 1:
                # Concat tensors of tile_data_list
                accumulated_tile_data = {}
                for key in tile_data_list[0].keys():
                    if isinstance(tile_data_list[0][key], list):
                        accumulated_tile_data[key] = [
                            item for _tile_data in tile_data_list
                            for item in _tile_data[key]
                        ]
                    elif isinstance(tile_data_list[0][key], torch.Tensor):
                        accumulated_tile_data[key] = torch.cat(
                            [_tile_data[key] for _tile_data in tile_data_list],
                            dim=0)
                    else:
                        raise TypeError(
                            f"Type {type(tile_data_list[0][key])} is not handled!"
                        )
                tile_data_list = []  # Empty tile_data_list
            else:
                # tile_data_list is not full yet, continue running inference...
                continue

            # --- Polygonize
            if compute_polygonization:
                crossfield = accumulated_tile_data[
                    "crossfield"] if "crossfield" in accumulated_tile_data else None
                accumulated_tile_data["polygons"], accumulated_tile_data[
                    "polygon_probs"] = polygonize.polygonize(
                        self.config["polygonize_params"],
                        accumulated_tile_data["seg"],
                        crossfield_batch=crossfield,
                        pool=pool)

            # --- Save output
            if self.config["eval_params"]["save_individual_outputs"]["seg_mask"] or \
                    self.config["eval_params"]["save_aggregated_outputs"]["seg_coco"]:
                # Take seg_interior:
                seg_pred_mask = self.config["eval_params"][
                    "seg_threshold"] < accumulated_tile_data["seg"][:, 0, ...]
                accumulated_tile_data["seg_mask"] = seg_pred_mask

            accumulated_tile_data = local_utils.batch_to_cpu(
                accumulated_tile_data)
            sample_list = local_utils.split_batch(accumulated_tile_data)

            # Save individual outputs:
            if save_individual_outputs:
                for sample in sample_list:
                    saver_async.add_work(sample)

            # Store aggregated outputs:
            if save_aggregated_outputs:
                self.shared_dict["name_list"].extend(
                    accumulated_tile_data["name"])
                if self.config["eval_params"]["save_aggregated_outputs"][
                        "stats"]:
                    y_pred = accumulated_tile_data["seg"][:, 0, ...].cpu()
                    if "gt_mask" in accumulated_tile_data:
                        y_true = accumulated_tile_data["gt_mask"][:, 0, ...]
                    elif "gt_polygons_image" in accumulated_tile_data:
                        y_true = accumulated_tile_data[
                            "gt_polygons_image"][:, 0, ...]
                    else:
                        raise ValueError(
                            "Either gt_mask or gt_polygons_image should be in accumulated_tile_data"
                        )
                    iou = measures.iou(
                        y_pred.reshape(y_pred.shape[0], -1),
                        y_true.reshape(y_true.shape[0], -1),
                        threshold=self.config["eval_params"]["seg_threshold"])
                    self.shared_dict["iou_list"].extend(iou.cpu().numpy())
                if self.config["eval_params"]["save_aggregated_outputs"][
                        "seg_coco"]:
                    for sample in sample_list:
                        annotations = save_utils.seg_coco(sample)
                        self.shared_dict["seg_coco_list"].extend(annotations)
                if self.config["eval_params"]["save_aggregated_outputs"][
                        "poly_coco"]:
                    for sample in sample_list:
                        annotations = save_utils.poly_coco(
                            sample["polygons"], sample["polygon_probs"],
                            sample["image_id"].item())
                        self.shared_dict["poly_coco_list"].append(
                            annotations
                        )  # annotations could be a dict, or a list
        # END of loop over samples

        # Save aggregated results
        if save_aggregated_outputs:
            self.barrier.wait(
            )  # Wait on all processes so that shared_dict is synchronized.
            if self.gpu == 0:
                if self.config["eval_params"]["save_aggregated_outputs"][
                        "stats"]:
                    print("Start saving stats:")
                    # Save sample_stats in CSV:
                    t1 = time.time()
                    stats_filepath = os.path.join(
                        self.eval_dirpath, "{}.stats.csv".format(split_name))
                    stats_file = open(stats_filepath, "w")
                    fnames = ["name", "iou"]
                    writer = csv.DictWriter(stats_file, fieldnames=fnames)
                    writer.writeheader()
                    for name, iou in sorted(zip(self.shared_dict["name_list"],
                                                self.shared_dict["iou_list"]),
                                            key=lambda pair: pair[0]):
                        writer.writerow({"name": name, "iou": iou})
                    stats_file.close()
                    print(f"Finished in {time.time() - t1:02}s")

                if self.config["eval_params"]["save_aggregated_outputs"][
                        "seg_coco"]:
                    print("Start saving seg_coco:")
                    t1 = time.time()
                    seg_coco_filepath = os.path.join(
                        self.eval_dirpath,
                        "{}.annotation.seg.json".format(split_name))
                    python_utils.save_json(
                        seg_coco_filepath,
                        list(self.shared_dict["seg_coco_list"]))
                    print(f"Finished in {time.time() - t1:02}s")

                if self.config["eval_params"]["save_aggregated_outputs"][
                        "poly_coco"]:
                    print("Start saving poly_coco:")
                    poly_coco_base_filepath = os.path.join(
                        self.eval_dirpath, f"{split_name}.annotation.poly")
                    t1 = time.time()
                    save_utils.save_poly_coco(
                        self.shared_dict["poly_coco_list"],
                        poly_coco_base_filepath)
                    print(f"Finished in {time.time() - t1:02}s")

        # Sync point of individual outputs
        if save_individual_outputs:
            print_utils.print_info(
                f"GPU {self.gpu} -> INFO: Finishing saving individual outputs."
            )
            saver_async.join()
            self.barrier.wait(
            )  # Wait on all processes so that all saver_asyncs are finished
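
The Manager-backed shared_dict plus barrier.wait() is what lets rank 0 aggregate results written by all processes, as in the saving code above. A minimal self-contained sketch of that synchronization pattern:

import torch.multiprocessing as mp

def worker(rank, barrier, shared):
    shared[rank] = rank * rank  # each process contributes its result
    barrier.wait()              # ensure every process has written
    if rank == 0:
        print(dict(shared))     # rank 0 aggregates once all are done

if __name__ == "__main__":
    manager = mp.Manager()
    shared = manager.dict()
    barrier = manager.Barrier(2)
    mp.spawn(worker, nprocs=2, args=(barrier, shared))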
Example 17
    def __init__(self,
                 root: str,
                 fold: str = "train",
                 pre_process: bool = True,
                 patch_size: int = None,
                 pre_transform=None,
                 transform=None,
                 small: bool = False,
                 pool_size: int = 1,
                 raw_dirname: str = "raw",
                 processed_dirname: str = "processed"):
        """

        @param root:
        @param fold:
        @param pre_process: If True, the dataset will be pre-processed first, saving training patches on disk. If False, data will be served on-the-fly without any patching.
        @param patch_size:
        @param pre_transform:
        @param transform:
        @param small: If True, use a small subset of the dataset (for testing)
        @param pool_size:
        @param processed_dirname:
        """
        self.root = root
        self.fold = fold
        self.pre_process = pre_process
        self.patch_size = patch_size
        self.pre_transform = pre_transform
        self.transform = transform
        self.small = small
        if self.small:
            print_utils.print_info(
                "INFO: Using small version of the xView2 xBD dataset.")
        self.pool_size = pool_size
        self.raw_dirname = raw_dirname

        if self.pre_process:
            # Setup of pre-process
            self.processed_dirpath = os.path.join(self.root, processed_dirname,
                                                  self.fold)
            stats_filepath = os.path.join(
                self.processed_dirpath,
                "stats-small.pt" if self.small else "stats.pt")
            processed_relative_paths_filepath = os.path.join(
                self.processed_dirpath, "processed_paths-small.json"
                if self.small else "processed_paths.json")

            # Check if dataset has finished pre-processing by checking processed_relative_paths_filepath:
            if os.path.exists(processed_relative_paths_filepath):
                # Process done, load stats and processed_relative_paths
                self.stats = torch.load(stats_filepath)
                self.processed_relative_paths = python_utils.load_json(
                    processed_relative_paths_filepath)
            else:
                # Pre-process not finished, launch it:
                tile_info_list = self.get_tile_info_list()
                self.stats = self.process(tile_info_list)
                # Save stats
                torch.save(self.stats, stats_filepath)
                # Save processed_relative_paths
                self.processed_relative_paths = [
                    tile_info["processed_relative_filepath"]
                    for tile_info in tile_info_list
                ]
                python_utils.save_json(processed_relative_paths_filepath,
                                       self.processed_relative_paths)
        else:
            # Setup data sample list
            self.tile_info_list = self.get_tile_info_list()
Example 18
def main():
    # Test using transforms from the frame_field_learning project:
    from frame_field_learning import data_transforms

    config = {
        "data_dir_candidates":
        ["/data/titane/user/nigirard/data", "~/data", "/data"],
        "dataset_params": {
            "root_dirname": "xview2_xbd_dataset",
            "pre_process": True,
            "small": False,
            "data_patch_size": 725,
            "input_patch_size": 512,
            "train_fraction": 0.75
        },
        "num_workers":
        8,
        "data_aug_params": {
            "enable": True,
            "vflip": True,
            "affine": True,
            "scaling": [0.9, 1.1],
            "color_jitter": True,
            "device": "cuda"
        }
    }

    # Find data_dir
    data_dir = python_utils.choose_first_existing_path(
        config["data_dir_candidates"])
    if data_dir is None:
        print_utils.print_error("ERROR: Data directory not found!")
        exit()
    else:
        print_utils.print_info("Using data from {}".format(data_dir))
    root_dir = os.path.join(data_dir, config["dataset_params"]["root_dirname"])

    # --- Transforms: --- #
    # --- pre-processing transform (done once then saved on disk):
    # --- Online transform done on the host (CPU):
    online_cpu_transform = data_transforms.get_online_cpu_transform(
        config, augmentations=config["data_aug_params"]["enable"])
    train_online_cuda_transform = data_transforms.get_online_cuda_transform(
        config, augmentations=config["data_aug_params"]["enable"])
    kwargs = {
        "pre_process": config["dataset_params"]["pre_process"],
        "transform": online_cpu_transform,
        "patch_size": config["dataset_params"]["data_patch_size"],
        "pre_transform": data_transforms.get_offline_transform_patch(),
        "small": config["dataset_params"]["small"],
        "pool_size": config["num_workers"],
    }
    # --- --- #
    fold = "train"
    if fold == "train":
        dataset = xView2Dataset(root_dir, fold="train", **kwargs)
    elif fold == "val":
        dataset = xView2Dataset(root_dir, fold="train", **kwargs)
    elif fold == "test":
        dataset = xView2Dataset(root_dir, fold="test", **kwargs)
    else:
        raise NotImplementedError

    print(f"dataset has {len(dataset)} samples.")
    print("# --- Sample 0 --- #")
    sample = dataset[0]
    for key, item in sample.items():
        print("{}: {}".format(key, type(item)))

    print("# --- Samples --- #")
    # for data in tqdm(dataset):
    #     pass

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=config["num_workers"])
    print("# --- Batches --- #")
    for batch in tqdm(data_loader):

        # batch["distances"] = batch["distances"].float()
        # batch["sizes"] = batch["sizes"].float()

        # im = np.array(batch["image"][0])
        # im = np.moveaxis(im, 0, -1)
        # skimage.io.imsave('im_before_transform.png', im)
        #
        # distances = np.array(batch["distances"][0])
        # distances = np.moveaxis(distances, 0, -1)
        # skimage.io.imsave('distances_before_transform.png', distances)
        #
        # sizes = np.array(batch["sizes"][0])
        # sizes = np.moveaxis(sizes, 0, -1)
        # skimage.io.imsave('sizes_before_transform.png', sizes)

        print("----")
        print(batch["name"])

        print("image:", batch["image"].shape, batch["image"].min().item(),
              batch["image"].max().item())
        im = np.array(batch["image"][0])
        im = np.moveaxis(im, 0, -1)
        skimage.io.imsave('im.png', im)

        if "gt_polygons_image" in batch:
            print("gt_polygons_image:", batch["gt_polygons_image"].shape,
                  batch["gt_polygons_image"].min().item(),
                  batch["gt_polygons_image"].max().item())
            seg = np.array(batch["gt_polygons_image"][0]) / 255
            seg = np.moveaxis(seg, 0, -1)
            seg_display = utils.get_seg_display(seg)
            seg_display = (seg_display * 255).astype(np.uint8)
            skimage.io.imsave("gt_seg.png", seg_display)

        if "gt_crossfield_angle" in batch:
            print("gt_crossfield_angle:", batch["gt_crossfield_angle"].shape,
                  batch["gt_crossfield_angle"].min().item(),
                  batch["gt_crossfield_angle"].max().item())
            gt_crossfield_angle = np.array(batch["gt_crossfield_angle"][0])
            gt_crossfield_angle = np.moveaxis(gt_crossfield_angle, 0, -1)
            skimage.io.imsave('gt_crossfield_angle.png', gt_crossfield_angle)

        if "distances" in batch:
            print("distances:", batch["distances"].shape,
                  batch["distances"].float().min().item(),
                  batch["distances"].float().max().item())
            distances = np.array(batch["distances"][0])
            distances = np.moveaxis(distances, 0, -1)
            skimage.io.imsave('distances.png', distances)

        if "sizes" in batch:
            print("sizes:", batch["sizes"].shape,
                  batch["sizes"].float().min().item(),
                  batch["sizes"].float().max().item())
            sizes = np.array(batch["sizes"][0])
            sizes = np.moveaxis(sizes, 0, -1)
            skimage.io.imsave('sizes.png', sizes)

        # valid_mask = np.array(batch["valid_mask"][0])
        # valid_mask = np.moveaxis(valid_mask, 0, -1)
        # skimage.io.imsave('valid_mask.png', valid_mask)

        input("Press enter to continue...")

        print("Apply online tranform:")
        batch = utils.batch_to_cuda(batch)
        batch = train_online_cuda_transform(batch)
        batch = utils.batch_to_cpu(batch)

        print("image:", batch["image"].shape, batch["image"].min().item(),
              batch["image"].max().item())
        print("gt_polygons_image:", batch["gt_polygons_image"].shape,
              batch["gt_polygons_image"].min().item(),
              batch["gt_polygons_image"].max().item())
        print("gt_crossfield_angle:", batch["gt_crossfield_angle"].shape,
              batch["gt_crossfield_angle"].min().item(),
              batch["gt_crossfield_angle"].max().item())
        # print("distances:", batch["distances"].shape, batch["distances"].min().item(), batch["distances"].max().item())
        # print("sizes:", batch["sizes"].shape, batch["sizes"].min().item(), batch["sizes"].max().item())

        # Save output to visualize
        seg = np.array(batch["gt_polygons_image"][0])
        seg = np.moveaxis(seg, 0, -1)
        seg_display = utils.get_seg_display(seg)
        seg_display = (seg_display * 255).astype(np.uint8)
        skimage.io.imsave("gt_seg.png", seg_display)

        im = np.array(batch["image"][0])
        im = np.moveaxis(im, 0, -1)
        skimage.io.imsave('im.png', im)

        gt_crossfield_angle = np.array(batch["gt_crossfield_angle"][0])
        gt_crossfield_angle = np.moveaxis(gt_crossfield_angle, 0, -1)
        skimage.io.imsave('gt_crossfield_angle.png', gt_crossfield_angle)

        distances = np.array(batch["distances"][0])
        distances = np.moveaxis(distances, 0, -1)
        skimage.io.imsave('distances.png', distances)

        sizes = np.array(batch["sizes"][0])
        sizes = np.moveaxis(sizes, 0, -1)
        skimage.io.imsave('sizes.png', sizes)

        # valid_mask = np.array(batch["valid_mask"][0])
        # valid_mask = np.moveaxis(valid_mask, 0, -1)
        # skimage.io.imsave('valid_mask.png', valid_mask)

        input("Press enter to continue...")
def main():
    # Test using transforms from the frame_field_learning project:
    from frame_field_learning import data_transforms

    config = {
        "data_dir_candidates": [
            "/data/titane/user/nigirard/data", "~/data", "/data",
            "/home/krishna/building-footprints-custom/frameField/data"
        ],
        "dataset_params": {
            "small": True,
            "root_dirname": "mapping_challenge_dataset",
            "seed": 0,
            "train_fraction": 0.75
        },
        "num_workers":
        8,
        "data_aug_params": {
            "enable": False,
            "vflip": True,
            "affine": True,
            "color_jitter": True,
            "device": "cuda"
        }
    }

    # Find data_dir
    data_dir = python_utils.choose_first_existing_path(
        config["data_dir_candidates"])
    if data_dir is None:
        print_utils.print_error("ERROR: Data directory not found!")
        exit()
    else:
        print_utils.print_info("Using data from {}".format(data_dir))
    root_dir = os.path.join(data_dir, config["dataset_params"]["root_dirname"])

    # --- Transforms: --- #
    # --- pre-processing transform (done once then saved on disk):
    # --- Online transform done on the host (CPU):
    train_online_cpu_transform = data_transforms.get_online_cpu_transform(
        config, augmentations=config["data_aug_params"]["enable"])
    test_online_cpu_transform = data_transforms.get_eval_online_cpu_transform()

    train_online_cuda_transform = data_transforms.get_online_cuda_transform(
        config, augmentations=config["data_aug_params"]["enable"])
    # --- --- #

    dataset = MappingChallenge(
        root_dir,
        transform=test_online_cpu_transform,
        pre_transform=data_transforms.get_offline_transform_patch(),
        fold="train",
        small=config["dataset_params"]["small"],
        pool_size=config["num_workers"])

    print("# --- Sample 0 --- #")
    sample = dataset[0]
    print(sample.keys())

    for key, item in sample.items():
        print("{}: {}".format(key, type(item)))

    print(sample["image"].shape)
    print(len(sample["gt_polygons_image"]))
    print("# --- Samples --- #")
    # for data in tqdm(dataset):
    #     pass

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=10,
        shuffle=True,
        num_workers=config["num_workers"])
    print("# --- Batches --- #")
    for batch in tqdm(data_loader):
        print("Images:")
        print(batch["image_relative_filepath"])
        print(batch["image"].shape)
        print(batch["gt_polygons_image"].shape)

        print("Apply online tranform:")
        batch = utils.batch_to_cuda(batch)
        batch = train_online_cuda_transform(batch)
        batch = utils.batch_to_cpu(batch)

        print(batch["image"].shape)
        print(batch["gt_polygons_image"].shape)

        # Save output to visualize
        seg = np.array(batch["gt_polygons_image"][0])
        seg = np.moveaxis(seg, 0, -1)
        seg_display = utils.get_seg_display(seg)
        seg_display = (seg_display * 255).astype(np.uint8)
        skimage.io.imsave("gt_seg.png", seg_display)
        skimage.io.imsave("gt_seg_edge.png", seg[:, :, 1])

        im = np.array(batch["image"][0])
        im = np.moveaxis(im, 0, -1)
        skimage.io.imsave('im.png', im)

        gt_crossfield_angle = np.array(batch["gt_crossfield_angle"][0])
        gt_crossfield_angle = np.moveaxis(gt_crossfield_angle, 0, -1)
        skimage.io.imsave('gt_crossfield_angle.png', gt_crossfield_angle)

        distances = np.array(batch["distances"][0])
        distances = np.moveaxis(distances, 0, -1)
        skimage.io.imsave('distances.png', distances)

        sizes = np.array(batch["sizes"][0])
        sizes = np.moveaxis(sizes, 0, -1)
        skimage.io.imsave('sizes.png', sizes)

        # valid_mask = np.array(batch["valid_mask"][0])
        # valid_mask = np.moveaxis(valid_mask, 0, -1)
        # skimage.io.imsave('valid_mask.png', valid_mask)

        input("Press enter to continue...")
    def __init__(self,
                 root: str,
                 fold: str = "train",
                 pre_process: bool = True,
                 tile_filter=None,
                 patch_size: int = None,
                 patch_stride: int = None,
                 pre_transform=None,
                 transform=None,
                 small: bool = False,
                 pool_size: int = 1,
                 raw_dirname: str = "raw",
                 processed_dirname: str = "processed",
                 gt_source: str = "disk",
                 gt_type: str = "npy",
                 gt_dirname: str = "gt_polygons",
                 mask_only: bool = False):
        """

        @param root:
        @param fold:
        @param pre_process: If True, the dataset will be pre-processed first, saving training patches on disk. If False, data will be served on-the-fly without any patching.
        @param tile_filter: Function to call on tile_info, if returns True, include that tile. If returns False, exclude that tile. Does not affect pre-processing.
        @param patch_size:
        @param patch_stride:
        @param pre_transform:
        @param transform:
        @param small: If True, use a small subset of the dataset (for testing)
        @param pool_size:
        @param processed_dirname:
        @param gt_source: Can be "disk" for annotations that are on disk or "osm" to download from OSM (not implemented)
        @param gt_type: Type of annotation files on disk: can be "npy", "geojson" or "tif"
        @param gt_dirname: Name of directory with annotation files
        @param mask_only: If True, discard the RGB image, sample's "image" field is a single-channel binary mask of the polygons and there is no ground truth segmentation.
            This is to allow learning only the frame field from binary masks in order to polygonize binary masks
        """
        assert gt_source in {"disk", "osm"}, "gt_source should be disk or osm"
        assert gt_type in {
            "npy", "geojson", "tif"
        }, f"gt_type should be npy, geojson or tif, not {gt_type}"
        self.root = root
        self.fold = fold
        self.pre_process = pre_process
        self.tile_filter = tile_filter
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.pre_transform = pre_transform
        self.transform = transform
        self.small = small
        if self.small:
            print_utils.print_info(
                "INFO: Using small version of the Inria dataset.")
        self.pool_size = pool_size
        self.raw_dirname = raw_dirname
        self.gt_source = gt_source
        self.gt_type = gt_type
        self.gt_dirname = gt_dirname
        self.mask_only = mask_only

        # Fill default values
        if self.gt_source == "disk":
            print_utils.print_info(
                "INFO: annotations will be loaded from disk")
        elif self.gt_source == "osm":
            print_utils.print_info(
                "INFO: annotations will be downloaded from OSM. "
                "Make sure you have an internet connection to the OSM servers!"
            )

        if self.pre_process:
            # Setup of pre-process
            processed_dirname_extension = f"{processed_dirname}.source_{self.gt_source}.type_{self.gt_type}"
            if self.gt_dirname is not None:
                processed_dirname_extension += f".dirname_{self.gt_dirname}"
            if self.mask_only:
                processed_dirname_extension += f".mask_only_{int(self.mask_only)}"
            processed_dirname_extension += f".patch_size_{int(self.patch_size)}"
            self.processed_dirpath = os.path.join(self.root,
                                                  processed_dirname_extension,
                                                  self.fold)
                                                  self.fold)
            self.stats_filepath = os.path.join(
                self.processed_dirpath,
                "stats-small.pt" if self.small else "stats.pt")
            self.processed_flag_filepath = os.path.join(
                self.processed_dirpath,
                "processed_flag-small" if self.small else "processed_flag")

            # Check if dataset has finished pre-processing by checking flag:
            if os.path.exists(self.processed_flag_filepath):
                # Process done, load stats
                self.stats = torch.load(self.stats_filepath)
            else:
                # Pre-process not finished, launch it:
                tile_info_list = self.get_tile_info_list(tile_filter=None)
                self.stats = self.process(tile_info_list)
                # Save stats
                torch.save(self.stats, self.stats_filepath)
                # Mark dataset as processed with flag
                pathlib.Path(self.processed_flag_filepath).touch()

            # Get processed_relative_paths with filter
            tile_info_list = self.get_tile_info_list(
                tile_filter=self.tile_filter)
            self.processed_relative_paths = self.get_processed_relative_paths(
                tile_info_list)
        else:
            # Setup data sample list
            self.tile_info_list = self.get_tile_info_list(
                tile_filter=self.tile_filter)
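
A minimal instantiation sketch; the enclosing class name, the root path, and the tile_info key used by the filter are all assumptions for illustration:

def keep_austin(tile_info):
    return tile_info["city"] == "austin"  # hypothetical tile_info key

dataset = InriaAerial("/data/inria_dataset", fold="train",  # hypothetical class name
                      patch_size=725, tile_filter=keep_austin)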