コード例 #1
0
    def test_mixed_nan_to_num(self):
        """Check maxed_nan_to_num against a 50% NaN array at two max_ratio values.

        With max_ratio above the NaN share the NaNs are expected to be filled;
        with max_ratio below it the call is expected to return None.
        """
        # Two of the four values are NaN, i.e. a NaN share of 0.5.
        values = np.array([1, np.nan, 1, np.nan])

        # max_ratio=0.75 tolerates a 0.5 NaN share: NaNs become the fill value.
        filled = BaseEngineer.maxed_nan_to_num(values, nan=0, max_ratio=0.75)
        expected = np.array([1, 0, 1, 0])
        assert (filled == expected).all()

        # max_ratio=0.25 is exceeded by a 0.5 NaN share: no array comes back.
        rejected = BaseEngineer.maxed_nan_to_num(values, nan=0, max_ratio=0.25)
        assert rejected is None
コード例 #2
0
ファイル: utils.py プロジェクト: nasaharvest/crop-maml
def tif_to_np(
    path_to_dataset: Path,
    add_ndvi: bool,
    add_ndwi: bool,
    nan: float,
    normalizing_dict: Optional[Dict[str, np.ndarray]],
    days_per_timestep: int,
) -> TestInstance:
    """Load a tif as a pixel-first array together with flat pixel coordinates.

    Args:
        path_to_dataset: tif to load; its file name is parsed with
            ``BaseEngineer.process_filename`` to obtain the start date.
        add_ndvi: if True, pass the array through ``BaseEngineer.calculate_ndvi``.
        add_ndwi: if True, pass the array through ``BaseEngineer.calculate_ndwi``.
        nan: fill value handed to ``BaseEngineer.maxed_nan_to_num``.
        normalizing_dict: optional mapping with ``"mean"`` and ``"std"``
            entries; when given, the array is standardised as
            ``(x - mean) / std``.
        days_per_timestep: forwarded to ``BaseEngineer.load_tif``.

    Returns:
        A ``TestInstance`` whose ``x`` has the flattened spatial axis first
        and whose ``lat``/``lon`` are 1-D arrays of pixel coordinates.
    """
    filename_parts = cast(
        Tuple[str, datetime, datetime],
        BaseEngineer.process_filename(path_to_dataset.name,
                                      include_extended_filenames=True),
    )
    _, start_date, _ = filename_parts

    dataset = BaseEngineer.load_tif(path_to_dataset,
                                    days_per_timestep=days_per_timestep,
                                    start_date=start_date)

    # meshgrid (default "xy" indexing) gives grids shaped (len(y), len(x));
    # ravel() flattens them in row-major order. This lines up with the data
    # reshape below assuming the last two axes of the loaded array are (y, x)
    # -- TODO confirm against BaseEngineer.load_tif.
    lon_grid, lat_grid = np.meshgrid(dataset.x.values, dataset.y.values)
    flat_lat = lat_grid.ravel()
    flat_lon = lon_grid.ravel()

    pixels = dataset.values
    dims = pixels.shape
    # Merge the two trailing (spatial) axes, then move that merged axis first.
    pixels = pixels.reshape(dims[0], dims[1], dims[2] * dims[3])
    pixels = np.moveaxis(pixels, -1, 0)

    if add_ndvi:
        pixels = BaseEngineer.calculate_ndvi(pixels, num_dims=3)
    if add_ndwi:
        pixels = BaseEngineer.calculate_ndwi(pixels, num_dims=3)

    # NOTE(review): maxed_nan_to_num returns None when a max_ratio is exceeded
    # (see its tests); confirm its default max_ratio never does so here.
    pixels = BaseEngineer.maxed_nan_to_num(pixels, nan=nan)

    if normalizing_dict is not None:
        pixels = (pixels - normalizing_dict["mean"]) / normalizing_dict["std"]

    return TestInstance(x=pixels, lat=flat_lat, lon=flat_lon)
コード例 #3
0
    def test_randomly_selected_latlons_not_labels(self, monkeypatch):
        """Check that randomly_select_latlon skips a candidate equal to the label.

        A candidate is expected to be skipped only when BOTH its lat and lon
        match the label point; a partial match is kept.
        """
        def return_const_latlons(array, size, replace):
            # Deterministic stand-in for np.random.choice: always returns the
            # first `size` items, so candidates are tried in array order.
            return array[:size]

        monkeypatch.setattr(np.random, "choice", return_const_latlons)

        lat_array = np.array([1, 2, 3])
        lon_array = np.array([4, 5, 6])

        # test case 1 - the label (1, 4) matches the first candidate in both
        # lat and lon, so that candidate should be skipped
        label_lat, label_lon = 1, 4
        random_lat, random_lon = BaseEngineer.randomly_select_latlon(
            lat_array, lon_array, label_lat, label_lon)
        assert (random_lat == 2) and (random_lon == 5)

        # test case 2 - only the lat matches (label lon 2 != candidate lon 4),
        # so the first candidate is not the label and shouldn't be skipped
        label_lat, label_lon = 1, 2
        random_lat, random_lon = BaseEngineer.randomly_select_latlon(
            lat_array, lon_array, label_lat, label_lon)
        assert (random_lat == 1) and (random_lon == 4)
コード例 #4
0
    def test_find_nearest(self):
        """Check that find_nearest maps the target 1.1 to the element 1."""
        # 1.1 is 0.1 away from 1 and at least 0.9 away from every other entry.
        candidates = np.array([1, 2, 3, 4, 5])
        nearest = BaseEngineer.find_nearest(candidates, 1.1)
        assert nearest == 1
コード例 #5
0
def plot_results(model_preds: xr.Dataset,
                 tci_path: Path,
                 savepath: Path,
                 prefix: str = "") -> None:
    """Save a figure comparing a tile's true-colour image with model output.

    Three panels are drawn: the true-colour image, the prediction mask, and
    the mask (alpha 0.3) on top of the true-colour image. The figure is saved
    to ``savepath / f"results_{prefix}{tci_path.name}.png"`` at 300 dpi.

    Args:
        model_preds: predictions with ``lat``/``lon`` coordinates. With a
            single data variable, ``prediction_0`` is drawn on a fixed 0-1
            colour scale and thresholded at 0.5 for the overlay. With several
            data variables, the per-pixel argmax across them is drawn as a
            class map and the colourbar ticks are labelled with the variable
            names.
        tci_path: tif loaded with ``BaseEngineer.load_tif`` and converted with
            ``sentinel_as_tci``; only its last time step is shown.
        savepath: directory the PNG is written into.
        prefix: optional string inserted into the output file name.
    """
    class_names = list(model_preds.data_vars)
    multi_output = len(class_names) > 1

    tci = sentinel_as_tci(
        BaseEngineer.load_tif(tci_path,
                              start_date=datetime(2020, 1, 1),
                              days_per_timestep=30),
        scale=False,
    ).isel(time=-1)

    tci = tci.sortby("x").sortby("y")
    model_preds = model_preds.sortby("lat").sortby("lon")

    plt.clf()
    fig, ax = plt.subplots(1,
                           3,
                           figsize=(20, 7.5),
                           subplot_kw={"projection": ccrs.PlateCarree()})

    # Always release the figure, even if drawing or saving raises, so that
    # repeated calls do not accumulate open pyplot figures.
    try:
        fig.suptitle(
            f"Model results for tile with bottom left corner:"
            f"\nat latitude {float(model_preds.lat.min())}"
            f"\n and longitude {float(model_preds.lon.min())}",
            fontsize=15,
        )

        # NOTE(review): tci and model_preds are sorted with y/lat ascending,
        # so row 0 is the southernmost row, yet every imshow below uses
        # origin="upper" (row 0 drawn at the top). Confirm the images are
        # not rendered upside down.
        img_extent_1 = (tci.x.min(), tci.x.max(), tci.y.min(), tci.y.max())
        img = np.clip(np.moveaxis(tci.values, 0, -1), 0, 1)
        tci_kwargs = {
            "origin": "upper",
            "extent": img_extent_1,
            "transform": ccrs.PlateCarree(),
        }

        # ax 1 - original
        ax[0].set_title("True colour image")
        ax[0].imshow(img, **tci_kwargs)

        # The mask reuses the true-colour extent; this assumes model_preds
        # covers the same footprint as the tif (not checked here).
        args_dict = dict(tci_kwargs)
        if multi_output:
            # to_array() stacks the data variables on axis 0, so the argmax
            # over that axis is the index of the highest-scoring class.
            mask = np.argmax(model_preds.to_array().values, axis=0)
            # One colour per class; at most ~10 classes are expected today,
            # well within tab20's 20 colours. pyplot.get_cmap is used because
            # matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in 3.9.
            args_dict["cmap"] = plt.get_cmap("tab20", len(class_names))
        else:
            mask = model_preds.prediction_0
            args_dict.update({"vmin": 0, "vmax": 1})

        # ax 2 - mask
        ax[1].set_title("Mask")
        im = ax[1].imshow(mask, **args_dict)

        # finally, all together
        ax[2].set_title("Mask on top of the true colour image")
        ax[2].imshow(img, **tci_kwargs)
        args_dict["alpha"] = 0.3
        overlay = mask if multi_output else mask > 0.5
        ax[2].imshow(overlay, **args_dict)

        colorbar_args = {"ax": ax.ravel().tolist()}
        if multi_output:
            def class_label(val, _pos):
                # Tick values can arrive as floats (e.g. 2.0), which cannot
                # index a list; round back to an integer class index.
                idx = int(round(val))
                return class_names[idx] if 0 <= idx < len(class_names) else ""

            # One tick per class, labelled with the target name.
            colorbar_args.update({
                "ticks": range(len(class_names)),
                "format": plt.FuncFormatter(class_label),
            })
        fig.colorbar(im, **colorbar_args)

        fig.savefig(savepath / f"results_{prefix}{tci_path.name}.png",
                    bbox_inches="tight",
                    dpi=300)
    finally:
        plt.close(fig)