示例#1
0
def test_compute_mse_error(tmp_path: Path) -> None:
    """Check ``compute_mse_error_csv`` on ground-truth CSVs exported from a zarr dataset.

    Three properties are exercised:

    * exporting the same dataset twice yields CSVs whose error is exactly zero;
    * adding a random offset to every agent centroid yields a strictly positive
      error for at least one entry;
    * comparing against a copy of the first CSV with its last 10 lines removed
      raises ``ValueError``.
    """
    gt1 = str(tmp_path / "gt1.csv")
    gt2 = str(tmp_path / "gt2.csv")
    gt3 = str(tmp_path / "gt3.csv")
    gt4 = str(tmp_path / "gt4.csv")

    data = ChunkedDataset(path="./l5kit/tests/artefacts/single_scene.zarr")
    data.open()
    export_zarr_to_ground_truth_csv(data, gt1, 0, 50, 0.5)
    data.open()  # avoid double select_agents
    export_zarr_to_ground_truth_csv(data, gt2, 0, 50, 0.5)
    err = compute_mse_error_csv(gt1, gt2)
    assert np.all(err == 0.0)

    # Copy the dataset arrays and shift every agent centroid by a random offset
    # in [0, 1). A fixed seed keeps the test reproducible across runs.
    data_fake = ChunkedDataset(str(tmp_path))
    data_fake.scenes = np.asarray(data.scenes).copy()
    data_fake.frames = np.asarray(data.frames).copy()
    data_fake.agents = np.asarray(data.agents).copy()
    data_fake.root = data.root
    rng = np.random.default_rng(0)
    data_fake.agents["centroid"] += rng.random(data_fake.agents["centroid"].shape)

    export_zarr_to_ground_truth_csv(data_fake, gt3, 0, 50, 0.5)
    err = compute_mse_error_csv(gt1, gt3)
    assert np.any(err > 0.0)

    # A ground-truth file truncated by 10 lines (built from gt1) must be rejected.
    # Both handles are closed deterministically via ``with``.
    with open(gt1) as src:
        lines = src.readlines()
    with open(gt4, "w") as dst:
        dst.writelines(lines[:-10])

    with pytest.raises(ValueError):
        compute_mse_error_csv(gt1, gt4)
示例#2
0
def test_compute_mse_error(tmp_path: Path,
                           zarr_dataset: ChunkedDataset) -> None:
    """Check ``compute_mse_error_csv`` on ground-truth CSVs exported from ``zarr_dataset``.

    Three properties are exercised:

    * exporting the same dataset twice yields CSVs whose error is exactly zero;
    * adding a small random offset (scaled by 1e-2) to every agent centroid
      yields a strictly positive error for at least one entry;
    * comparing against a copy of the first CSV with its last 10 lines removed
      raises ``ValueError``.
    """
    gt1 = str(tmp_path / "gt1.csv")
    gt2 = str(tmp_path / "gt2.csv")
    gt3 = str(tmp_path / "gt3.csv")
    gt4 = str(tmp_path / "gt4.csv")

    export_zarr_to_ground_truth_csv(zarr_dataset, gt1, 10, 50, 0.5)
    export_zarr_to_ground_truth_csv(zarr_dataset, gt2, 10, 50, 0.5)
    err = compute_mse_error_csv(gt1, gt2)
    assert np.all(err == 0.0)

    # Copy the dataset arrays and shift every agent centroid by a random offset
    # in [0, 1e-2). A fixed seed keeps the test reproducible across runs.
    data_fake = ChunkedDataset(str(tmp_path))
    data_fake.scenes = np.asarray(zarr_dataset.scenes).copy()
    data_fake.frames = np.asarray(zarr_dataset.frames).copy()
    data_fake.agents = np.asarray(zarr_dataset.agents).copy()
    rng = np.random.default_rng(0)
    data_fake.agents["centroid"] += rng.random(
        data_fake.agents["centroid"].shape) * 1e-2

    export_zarr_to_ground_truth_csv(data_fake, gt3, 10, 50, 0.5)
    err = compute_mse_error_csv(gt1, gt3)
    assert np.any(err > 0.0)

    # A ground-truth file truncated by 10 lines (built from gt1) must be rejected.
    # Both handles are closed deterministically via ``with``.
    with open(gt1) as src:
        lines = src.readlines()
    with open(gt4, "w") as dst:
        dst.writelines(lines[:-10])

    with pytest.raises(ValueError):
        compute_mse_error_csv(gt1, gt4)