示例#1
0
def evaluate(model_path: Path, datasets: typing.List[Datasets], output_folder: Path,
             find_mistakes: bool = False, include_heading: bool = False) -> str:
    """Evaluate a model on one or more datasets, returning the results as CSV.

    Args:
        model_path (Path): path to the saved model; the YAML config is expected
            next to it as ``<model stem>.yaml``
        datasets (typing.List[Datasets]): the dataset modes to evaluate on
        output_folder (Path): output folder for the mistake images (if applicable)
        find_mistakes (bool, optional): whether to output all mistakes as images to the output folder. Defaults to False.
        include_heading (bool, optional): whether to include a heading in the CSV output. Defaults to False.

    Raises:
        ValueError: if the YAML config file is missing or ``datasets`` is empty

    Returns:
        str: the CSV string (optional heading followed by one ``_csv`` entry per
            dataset, joined with newlines)
    """
    model_name = model_path.stem
    config_file = model_path.parent / f"{model_name}.yaml"
    if not config_file.exists():
        raise ValueError(f"config file missing: {config_file}")
    # Fail early with a clear error instead of a StopIteration further down
    # (and before the potentially expensive model load).
    if not datasets:
        raise ValueError("no datasets to evaluate on")
    cfg = CN.load_yaml_with_base(config_file)
    # NOTE(review): torch.load unpickles the whole model object, which can run
    # arbitrary code -- only load model files from trusted sources.
    model = torch.load(model_path, map_location=DEVICE)
    model = device(model)
    model.eval()
    # Map each requested mode to its dataset, built from the same config.
    built = {mode: build_dataset(cfg, mode) for mode in datasets}
    # Assumes every dataset exposes the same classes; the first one's are used.
    classes = next(iter(built.values())).classes

    csv = []
    if include_heading:
        csv.append(_csv_heading(classes))
    for mode, dataset in built.items():
        loader = build_data_loader(cfg, dataset, mode)
        # Compute statistics over the whole dataset.
        agg = StatsAggregator(classes)
        # Inference only: disable autograd so no computation graph is kept
        # alive for every batch (model.eval() alone does not do this).
        with torch.no_grad():
            for images, labels in device(loader):
                predictions = model(images)
                # Inputs are only handed over when mistakes must be rendered.
                extra = dict(inputs=images) if find_mistakes else dict()
                agg.add_batch(predictions, labels, **extra)

        csv.append(_csv(model, agg, model_name, mode))
        if find_mistakes:
            _save_mistakes(agg.mistakes, output_folder, model_name, mode)
    return "\n".join(csv)


def _save_mistakes(mistakes, output_folder: Path, model_name: str, mode) -> None:
    """Write a grid image of mispredicted inputs plus a CSV of their ground truth.

    ``mistakes`` is unpacked as ``(groundtruth, input)`` pairs. Does nothing
    (apart from logging) when there are no mistakes.
    """
    if not mistakes:
        # zip(*[]) would yield nothing and the unpacking below would fail.
        logger.info(f"No mistakes to write for {model_name} ({mode.value})")
        return
    # Sort by ground truth so mistakes of the same label are adjacent in the grid.
    groundtruth, inputs = zip(*sorted(mistakes, key=lambda x: x[0]))
    imgs = torch.tensor(inputs).permute((0, 2, 3, 1))
    imgs = unnormalize(imgs).permute((0, 3, 1, 2))
    grid = torchvision.utils.make_grid(imgs, pad_value=1, nrow=4)
    pixels = grid.numpy().transpose((1, 2, 0)) * 255
    # Clip before the cast: values outside [0, 255] would otherwise overflow uint8.
    img = Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))
    mistakes_file = output_folder / f"{model_name}_{mode.value}_mistakes.png"
    logger.info(f"Writing mistakes to {mistakes_file}")
    img.save(mistakes_file)
    groundtruth_file = output_folder / f"{model_name}_{mode.value}_groundtruth.csv"
    with groundtruth_file.open("w") as f:
        f.write(",".join(map(str, groundtruth)))
示例#2
0
def aggregator() -> StatsAggregator:
    """Return a two-class ("a", "b") aggregator fed with one fixed batch.

    Class-"a" scores are given explicitly; class-"b" scores are their
    complement (1 - score), so each row of the batch sums to 1.
    """
    scores_a = np.array([.9, .1, .8, .2, .9, .9, .9, .2])
    scores = torch.tensor(np.column_stack((scores_a, 1 - scores_a)))
    targets = torch.tensor([0, 0, 0, 1, 0, 0, 1, 0])
    # argmax over `scores` gives the predicted classes [0, 1, 0, 1, 0, 0, 0, 1]
    result = StatsAggregator(["a", "b"])
    result.add_batch(scores, targets)
    return result