def get_batch_test_set(config, reserve_batches):
    """Gather trailing sample batches into a held-out test set.

    The ``sample*.txt`` batch files under ``config["output"]`` are ordered
    with ``utils.get_sample`` and visited from the last to the first. A batch
    is added to the test set when it holds fewer than
    ``config["sampling_batch"]`` images, or while ``reserve_batches`` has not
    reached zero. Only full-size batches use up the reserve.

    Args:
        config (dict): retraining configuration.
        reserve_batches (int): number of full-size batches to hold out.

    Returns:
        tuple: ``(LabeledSet, int)`` -- the collected images and the number of
        batch files they were taken from.
    """
    output_dir = config["output"]
    class_count = len(utils.load_classes(config["class_list"]))
    ordered = sorted(glob.glob(f"{output_dir}/sample*.txt"),
                     key=utils.get_sample)

    held_out = []
    taken = 0
    remaining = reserve_batches
    for batch_file in ordered[::-1]:
        batch_imgs = utils.get_lines(batch_file)
        is_full = len(batch_imgs) >= config["sampling_batch"]
        # Full batches are only taken while the reserve is not exhausted;
        # partial batches are always taken.
        if is_full and remaining == 0:
            continue
        held_out += batch_imgs
        taken += 1
        if is_full:
            remaining -= 1
    return LabeledSet(held_out, class_count), taken
def get_test_sets(config, prefix):
    """Assemble the named test sets used to evaluate a sampling run.

    Args:
        config (dict): retraining configuration (reads ``output``,
            ``class_list`` and ``sample_set``).
        prefix (str): run name used in the per-iteration test file names
            (``<prefix><i>_test.txt``).

    Returns:
        dict: set name -> LabeledSet, with these entries:

        - ``"init"``: images listed in ``init_test.txt``.
        - ``"all_iter"``: every image from the per-iteration test files.
        - ``"sample"``: the ``all_iter`` images whose path contains
          ``config["sample_set"]``.
        - ``"all"``: the ``sample`` images plus the ``init`` set.
        - ``"cur_iter<k>"``: only when ``prefix != "init"``; the test file of
          iteration ``k - 1``.
    """
    out_dir = config["output"]
    num_classes = len(utils.load_classes(config["class_list"]))
    epoch_splits = utils.get_epoch_splits(config, prefix)

    # Initial test set
    init_test_set = f"{out_dir}/init_test.txt"
    init_test_folder = LabeledSet(init_test_set, num_classes)

    # Only data from the (combined) iteration test sets (75% sampling + 25% seen data)
    iter_test_sets = [
        f"{out_dir}/{prefix}{i}_test.txt" for i in range(len(epoch_splits))
    ]
    iter_img_files = list()
    for file in iter_test_sets:
        iter_img_files += utils.get_lines(file)
    all_iter_sets = LabeledSet(iter_img_files, num_classes)

    # Test sets filtered for only sampled images
    sampled_imgs = [
        img for img in iter_img_files if config["sample_set"] in img
    ]
    sample_test = LabeledSet(sampled_imgs, num_classes)

    # Combined test set. Pass a copy of the list: "all" is grown with `+=`
    # below and must not share its image list with the "sample" set.
    # NOTE(review): despite the name, this uses only the *sampled* iteration
    # images (not iter_img_files) plus the initial test set -- confirm intent.
    all_test = LabeledSet(list(sampled_imgs), num_classes)
    all_test += init_test_folder

    test_sets = {
        "init": init_test_folder,
        "all_iter": all_iter_sets,
        "sample": sample_test,
        "all": all_test,
    }

    if prefix != "init":
        # Keys are 1-based while the underlying test files are 0-based.
        for i, iter_file in enumerate(iter_test_sets):
            test_sets[f"cur_iter{i + 1}"] = LabeledSet(iter_file, num_classes)

    return test_sets
def benchmark_next_batch(prefix, config, opt):
    """See initial training performance on batch splits.

    Every ``sample*.txt`` batch file in ``config["output"]`` (ordered with
    ``utils.get_sample``) is wrapped in a LabeledSet and handed, together
    with the epoch splits for ``prefix`` and a CSV filename builder, to
    ``benchmark_batch_splits``.

    Args:
        prefix (str): run name. For ``"init"`` the epoch splits are
            multiplied (``*=``) by the number of batch files.
        config (dict): retraining configuration.
        opt: parsed options; ``opt.roll_avg`` selects the ``roll_`` or
            ``avg_`` tag in the output filename.
    """
    output_dir = config["output"]
    class_count = len(utils.load_classes(config["class_list"]))
    batch_files = sorted(glob.glob(f"{output_dir}/sample*.txt"),
                         key=utils.get_sample)

    splits = utils.get_epoch_splits(config, prefix, True)
    if prefix == "init":
        splits *= len(batch_files)

    batch_sets = [LabeledSet(path, class_count) for path in batch_files]

    def filename_for(i, end_epoch):
        # e.g. <out>/<prefix><i>_benchmark_roll_1_<end_epoch>.csv
        avg_tag = "roll_" if opt.roll_avg else "avg_"
        return (f"{output_dir}/{prefix}{i}_benchmark_"
                f"{avg_tag}1_{end_epoch}.csv")

    benchmark_batch_splits(prefix, batch_sets, splits, filename_for, config,
                           opt)
def sample_retrain(
    sample_method,
    batches,
    config,
    last_epoch,
    seen_images,
    label_func,
    device=None,
):
    """Run the sampling and retraining pipeline for a particular sampling function.

    For each batch in ``batches``:

    1. Label it.
    2. Sample images from it, or reload a previously saved sample list.
    3. Split the samples.
    4. Pad each split with previously seen images according to
       ``config["retrain_new"]``.
    5. Save the splits and retrain from the checkpoint found for
       ``last_epoch``.

    Args:
        sample_method (tuple): two-element tuple whose first element is the
            method name (used in filenames and set prefixes). The whole
            tuple is passed to ``benchmark_sample``.
        batches (iterable): batch folders exposing ``label(classes, func)``
            and ``imgs``.
        config (dict): retraining configuration. NOTE: ``config["start_epoch"]``
            is overwritten before each training round.
        last_epoch (int): epoch the model to continue from ended on.
        seen_images (LabeledSet): images from earlier rounds. It is
            deep-copied, so the caller's object is not modified.
        label_func: labeling callback forwarded to ``sample_folder.label``.
        device (str, optional): device forwarded to ``train.train``.
    """
    name, _ = sample_method
    classes = utils.load_classes(config["class_list"])
    seen_images = copy.deepcopy(seen_images)
    for i, sample_folder in enumerate(batches):
        sample_folder.label(classes, label_func)
        sample_labeled = LabeledSet(
            sample_folder.imgs,
            len(classes),
            config["img_size"],
        )

        sample_filename = f"{config['output']}/{name}{i}_sample_{last_epoch}.txt"
        if os.path.exists(sample_filename):
            print("Loading existing samples")
            # Close the handle deterministically, and drop blank lines: an
            # empty sample list is saved as "" and a trailing newline would
            # otherwise be read back as a bogus "" image path.
            with open(sample_filename, "r") as sample_file:
                retrain_files = [
                    line for line in sample_file.read().split("\n") if line
                ]

        else:
            retrain_files = benchmark_sample(sample_method, sample_labeled,
                                             config, i, last_epoch)

            # When deploying at the edge, this would be where data is
            # sent from nodes to the Beehive, along with the benchmark file
            with open(sample_filename, "w+") as out:
                out.write("\n".join(retrain_files))

        # Receive raw sampled data in the cloud
        # This process simulates manually labeling/verifying all inferences
        retrain_obj = LabeledSet(retrain_files,
                                 len(classes),
                                 config["img_size"],
                                 prefix=f"{name}{i}")

        new_splits_made = retrain_obj.load_or_split(
            config["output"],
            config["train_sample"],
            config["valid_sample"],
            save=False,
            sample_dir=config["sample_set"],
        )

        if new_splits_made:
            # If reloaded, splits have old images already incorporated
            for set_name in retrain_obj.sets:
                # Number of old images needed so that new images make up the
                # fraction `retrain_new` of this split:
                #   new / (new + old) = retrain_new
                #   =>  old = (1 / retrain_new - 1) * new
                number_desired = (1 / config["retrain_new"] - 1) * len(
                    getattr(retrain_obj, set_name))
                if round(number_desired) == 0:
                    continue
                print(set_name, number_desired)
                extra_images = getattr(seen_images, set_name).split_batch(
                    round(number_desired))[0]
                orig_set = getattr(retrain_obj, set_name)
                # NOTE(review): relies on LabeledSet's `+=` mutating in place;
                # the rebound local is not written back to retrain_obj.
                orig_set += extra_images

        seen_images += retrain_obj

        retrain_obj.save_splits(config["output"])
        retrain_obj.train.augment(config["images_per_class"])

        config["start_epoch"] = last_epoch + 1
        checkpoint = utils.find_checkpoint(config, name, last_epoch)
        last_epoch = train.train(retrain_obj,
                                 config,
                                 checkpoint,
                                 device=device)
from retrain import retrain, utils

if __name__ == "__main__":
    # Command-line entry point: prepare the initial image set, then either
    # train a baseline from scratch or reuse an existing checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        required=True,
        help="configuration for retraining",
    )
    parser.add_argument(
        "--reload_baseline",
        default=None,
        help="bypass initial training with a checkpoint",
    )
    opt = parser.parse_args()

    config = utils.parse_retrain_config(opt.config)
    classes = utils.load_classes(config["class_list"])

    init_images = LabeledSet(
        config["initial_set"],
        len(classes),
        config["img_size"],
        prefix="init",
    )
    init_images.load_or_split(
        config["output"],
        config["train_init"],
        config["valid_init"],
    )

    # A supplied baseline checkpoint skips initial training; its epoch is
    # recovered from the checkpoint path.
    if opt.reload_baseline is not None:
        init_end_epoch = utils.get_epoch(opt.reload_baseline)
    else:
        init_end_epoch = train_initial(init_images, config)
        print(f"Initial training ended on epoch {init_end_epoch}")
Exemple #6
0
def train(img_folder, opt, load_weights=None, device=None):
    """Trains a given image set, with an early stop.

    Epochs whose checkpoint (``<checkpoints>/<prefix>_ckpt_<epoch>.pth``)
    already exists are not retrained; the saved weights are loaded instead.

    Args:
        img_folder (ImageFolder): Data wrapper that must have been split into train, test,
            and validation sets. Its ``prefix`` names checkpoints and log entries.
        opt (dict): Configuration dictionary with hyperparameters for training
        load_weights (str): Path of initial weights, if training is resumed from a checkpoint.
        device (str): PyTorch CUDA string of a GPU device. Looks for available devices if
            none is provided.
    Returns:
        last_epoch (int): the epoch number where training was ended
    """
    os.makedirs(opt["checkpoints"], exist_ok=True)
    os.makedirs(opt["output"], exist_ok=True)

    model = models.get_train_model(opt, device)

    print(f"Using {model.device} for training")
    yoloutils.clear_vram()

    # Initiate model
    model.apply(yoloutils.weights_init_normal)

    if load_weights is not None:
        model.load_state_dict(
            torch.load(load_weights, map_location=model.device))

    class_names = utils.load_classes(opt["class_list"])

    # Get dataloader
    dataset = img_folder.train.to_dataset(multiscale=bool(opt["multiscale"]))
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt["batch_size"],
        shuffle=True,
        num_workers=opt["n_cpu"],
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    optimizer = torch.optim.Adam(model.parameters())

    metrics = [
        "grid_size",
        "loss",
        "x",
        "y",
        "w",
        "h",
        "conf",
        "cls",
        "cls_acc",
        "recall50",
        "recall75",
        "precision",
        "conf_obj",
        "conf_noobj",
    ]

    # Limit logging rate of batch metrics. The divisor and the result are both
    # clamped to >= 1 so an empty dataloader or `logs_per_epoch == 0` cannot
    # raise ZeroDivisionError here, and Logger never receives a 0 interval.
    log_freq = min(len(dataloader), opt.get("logs_per_epoch", 50))
    log_interval = max(1, len(dataloader) // max(1, log_freq))
    logger = Logger(opt["log"], img_folder.prefix, log_interval)

    successive_stops = 0
    prev_strip_loss = float("inf")

    end_epoch = opt["start_epoch"] + opt["max_epochs"]
    last_epoch = opt["start_epoch"]

    for epoch in range(opt["start_epoch"], end_epoch):
        last_epoch = epoch
        model.train()

        ckpt_path = f"{opt['checkpoints']}/{img_folder.prefix}_ckpt_{epoch}.pth"

        if not os.path.exists(ckpt_path):
            train_epoch(dataloader, epoch, end_epoch, model, optimizer,
                        metrics, logger, opt)
            if epoch % opt["checkpoint_interval"] == 0:
                save_ckpt(model, img_folder.prefix, epoch, opt["checkpoints"])

        else:
            # Map onto the training device, as for `load_weights` above, so a
            # checkpoint saved on another device can still be resumed.
            model.load_state_dict(
                torch.load(ckpt_path, map_location=model.device))

        # Use UP criteria for early stop: validation loss is checked at the
        # start and every `strip_len` epochs, and training stops once it has
        # increased for `successions` consecutive checks.
        if bool(opt["early_stop"]) and (epoch == opt["start_epoch"]
                                        or epoch % opt["strip_len"] == 0):
            print(
                f"\n---Evaluating validation set on epoch {epoch} for early stop---"
            )

            valid_results = evaluate.get_results(model, img_folder.valid, opt,
                                                 class_names, logger, epoch)

            if valid_results["val_loss"] > prev_strip_loss:
                successive_stops += 1
            else:
                successive_stops = 0
            print(f"Previous loss: {prev_strip_loss}")
            print(f"Current loss: {valid_results['val_loss']}")

            prev_strip_loss = valid_results["val_loss"]

            if successive_stops == opt["successions"]:
                print(f"Early stop at epoch {epoch}")
                break

        if epoch % opt["evaluation_interval"] == 0:
            print(f"\n---Evaluating test set on epoch {epoch}---")
            evaluate.get_results(model, img_folder.test, opt, class_names,
                                 logger, epoch)

    yoloutils.clear_vram()
    return last_epoch
def make_results_df(config, img_folder, detections_by_img, total_epochs):
    """Tabulate averaged detections against ground truth, one row per object.

    The detections for each image are grouped with
    ``yoloutils.group_average_bb`` (IoU threshold ``config["iou_thres"]``) and
    paired with ground-truth classes. Every matched box becomes a row. Every
    ground-truth class left unmatched becomes a row with an empty
    ``detected`` value and zero confidence.

    Args:
        config (dict): reads ``class_list`` and ``iou_thres``, and is
            forwarded to ``evaluate.match_detections``.
        img_folder: provides ``get_classes(label_path)`` for ground truth.
        detections_by_img (dict): image path -> detections. ``None`` values
            are handled and yield only "missed" rows.
        total_epochs (int): forwarded to ``yoloutils.group_average_bb``.

    Returns:
        pandas.DataFrame: columns ``file``, ``actual``, ``detected``,
        ``conf``, ``conf_std``, ``hit``.
    """
    metrics = [
        "file",
        "actual",
        "detected",
        "conf",
        "conf_std",
        "hit",
    ]

    classes = utils.load_classes(config["class_list"])

    # Rows are collected in a list and turned into a DataFrame once at the
    # end: DataFrame.append was removed in pandas 2.0 (and was quadratic).
    rows = []

    for path, detections in detections_by_img.items():
        # Copy, because matched classes are removed from this list below.
        ground_truths = list(
            img_folder.get_classes(utils.get_label_path(path)))
        detection_pairs = []
        if detections is not None:
            region_detections, regions_std = yoloutils.group_average_bb(
                detections, total_epochs, config["iou_thres"])

            # evaluate.save_image(region_detections, path, config, classes)
            if len(region_detections) == 1:
                # A single region is labeled with its detected class if that
                # class is present, otherwise with the sole ground truth, and
                # otherwise left unlabeled.
                detected_class = int(region_detections.numpy()[0][-1])
                if detected_class in ground_truths:
                    label = detected_class
                elif len(ground_truths) == 1:
                    label = ground_truths[0]
                else:
                    label = None
                detection_pairs = [(label, region_detections[0])]
            else:
                test_img = LabeledSet([path], len(classes))
                detection_pairs = evaluate.match_detections(
                    test_img, region_detections.unsqueeze(0), config)

        for truth, box in detection_pairs:
            if box is None:
                continue
            obj_conf, class_conf, pred_class = box.numpy()[4:]
            # regions_std is looked up by class confidence rounded to 3 places
            obj_std, class_std = regions_std[round(float(class_conf), 3)]

            row = {
                "file": path,
                "detected": classes[int(pred_class)],
                "actual": classes[int(truth)] if truth is not None else "",
                "conf": obj_conf * class_conf,
                "conf_std": math.sqrt(obj_std**2 + class_std**2),
            }
            row["hit"] = row["actual"] == row["detected"]
            rows.append(row)

            if truth is not None:
                # Consume the match so it is not also reported as missed.
                ground_truths.remove(int(truth))

        # Add rows for those missing detections
        for truth in ground_truths:
            rows.append({
                "file": path,
                "detected": "",
                "actual": classes[int(truth)],
                "conf": 0.0,
                "hit": False,
                "conf_std": 0.0,
            })

    return pd.DataFrame(rows, columns=metrics)
def simple_benchmark_avg(img_folder,
                         prefix,
                         start,
                         end,
                         total_epochs,
                         config,
                         roll=False):
    """Deprecated version of benchmark averaging, meant for single object
    detection within an image. Used for a fair comparison baseline on old models

    For every benchmarked checkpoint, the first detection of each image is
    recorded. Each image is then assigned the class with the highest
    confidence averaged over all benchmarked checkpoints; a checkpoint with
    no detection of that class contributes 0.

    Args:
        img_folder: dataset yielding ``(img_paths, input_imgs)`` batches and
            providing ``get_classes(label_path)``.
        prefix (str): checkpoint name prefix.
        start (int): first epoch for evenly spaced sampling (``roll=False``).
        end (int): last epoch to benchmark.
        total_epochs (int): number of checkpoints to benchmark.
        config (dict): reads ``n_cpu``, ``class_list``, ``checkpoints``,
            ``model_config``, ``img_size`` and ``conf_thres``.
        roll (bool): if True, use the ``total_epochs`` consecutive epochs
            ending at ``end`` (never below epoch 1). Otherwise use
            ``total_epochs`` values evenly spaced over ``[start, end]``,
            de-duplicated.

    Returns:
        pandas.DataFrame: indexed by image path, with columns ``file``,
        ``confs``, ``actual``, ``detected``, ``conf``, ``hit``.
    """

    loader = DataLoader(
        img_folder,
        batch_size=1,
        shuffle=False,
        num_workers=config["n_cpu"],
    )

    # Rows are indexed by image path (see `results.loc[path]` below). The
    # former `results.set_index("file")` call discarded its return value and
    # had no effect, so it was removed.
    results = pd.DataFrame(
        columns=["file", "confs", "actual", "detected", "conf", "hit"])

    classes = utils.load_classes(config["class_list"])

    if roll:
        checkpoints_i = list(range(max(1, end - total_epochs + 1), end + 1))
    else:
        checkpoints_i = sorted(
            set(
                np.linspace(start,
                            end,
                            total_epochs,
                            dtype=np.dtype(np.int16))))

    single = total_epochs == 1

    if not single:
        print("Benchmarking on epochs", checkpoints_i)

    for n in tqdm(checkpoints_i, "Benchmarking epochs", disable=single):
        ckpt = get_checkpoint(config["checkpoints"], prefix, n)

        # NOTE(review): parsed per checkpoint rather than hoisted; model
        # builders may mutate the parsed definition -- confirm before hoisting.
        model_def = yoloutils.parse_model_config(config["model_config"])
        model = models.get_eval_model(model_def, config["img_size"], ckpt)

        for img_paths, input_imgs in loader:
            path = img_paths[0]
            if path not in results.index:
                actual_class = classes[img_folder.get_classes(
                    utils.get_label_path(path))[0]]
                results.loc[path] = [
                    path, dict(), actual_class, None, None, None
                ]

            detections = evaluate.detect(input_imgs, config["conf_thres"],
                                         model)

            # The stored dict is mutated in place, so appends update the row.
            confs = results.at[path, "confs"]

            for detection in detections:
                if detection is None:
                    continue
                # Single-object baseline: only the first detection row is used
                (_, _, _, _, _, cls_conf, cls_pred) = detection.numpy()[0]
                confs.setdefault(cls_pred, []).append(cls_conf)

    # Write results through `.at`: rows yielded by `iterrows()` are copies,
    # so assigning to them (as before) never updated `results`.
    for path in results.index:
        best_class = None
        best_conf = float("-inf")

        for class_name, confs in results.at[path, "confs"].items():
            avg_conf = sum(confs) / len(checkpoints_i)

            if avg_conf > best_conf:
                best_conf = avg_conf
                best_class = class_name

        if best_class is not None:
            detected = classes[int(best_class)]
            results.at[path, "detected"] = detected
            results.at[path, "conf"] = best_conf
            results.at[path, "hit"] = results.at[path, "actual"] == detected
        else:
            results.at[path, "detected"] = ""
            results.at[path, "conf"] = 0.0
            results.at[path, "hit"] = False

    return results