Example #1
def perform_custom_validation(db,
                              fpaths,
                              dates,
                              start_row,
                              start_col,
                              comp_type,
                              mask=None):
    validate_file_path_list(fpaths)
    if len(dates) != len(fpaths):
        raise RuntimeError("Number of input paths does not match number of dates")
    if len(fpaths) == 0:
        print("No data")
        return
    pf = WMOValidationPointFetcher(db, retrieval_type=comp_type)
    data = load_npy_files(fpaths)
    _verify_grids_are_homogenous_shape(data)

    nr, nc = data[0].shape
    # Index grids mapping the data subwindow into the full EASE grid
    xj, xi = np.meshgrid(range(start_col, start_col + nc),
                         range(start_row, start_row + nr))
    if mask is not None:
        mask = mask[xi, xj]
    lon, lat = eg.v1_get_full_grid_lonlat(eg.ML)
    lon = lon[xi, xj]
    lat = lat[xi, xj]
    bounds = [lon.min(), lon.max(), lat.min(), lat.max()]
    pg = PointsGridder(lon, lat, invalid_mask=mask)
    date_to_grid = {d: g for d, g in zip(dates, data)}
    results = perform_bounded_validation(date_to_grid, pf, pg, bounds)
    output_validation_stats_grouped_by_month(results, [COL_DATE, COL_SCORE])
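
A hedged usage sketch for Example #1: get_db_session and RETRIEVAL_MIN appear in the other examples here, while the glob pattern, dates, and offsets below are placeholders, not the project's actual layout.

# Hypothetical driver; paths, dates, and offsets are placeholders.
import datetime as dt
import glob

db = get_db_session("../data/dbs/wmo_gsod.db")
fpaths = sorted(glob.glob("../data/grids/ft_*.npy"))
dates = [dt.date(2000, 1, 1) + dt.timedelta(i) for i in range(len(fpaths))]
perform_custom_validation(db, fpaths, dates, start_row=0, start_col=0,
                          comp_type=RETRIEVAL_MIN)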
Example #2
def main(args):
    if args.start_year > args.end_year:
        raise ValueError("start_year must be less than or equal to end_year")
    year_str = get_year_str(args.start_year, args.end_year)
    if args.region is not None:
        transform = REGION_TO_TRANS[args.region]
        land_mask = ~transform(np.load("../data/masks/ft_esdr_water_mask.npy"))
        lon, lat = [transform(i) for i in eg.v1_get_full_grid_lonlat(eg.ML)]
    else:
        assert args.lon is not None and args.lat is not None
        land_mask = ~args.water
        lon = args.lon
        lat = args.lat
    # TODO: add option for AM/PM
    ret_type = RETRIEVAL_MIN
    dates = load_dates(
        args.dates or f"../data/cleaned/date_map-{year_str}-{args.region}.csv")
    # TODO: add option for AM/PM
    db_path = args.db_path or "../data/dbs/wmo_gsod.db"
    aws_data = get_aws_data(
        dates,
        db_path,
        land_mask,
        lon,
        lat,
        ret_type,
    )
    # TODO: add option for AM/PM
    out_file = os.path.join(args.out_dir,
                            f"aws_data-AM-{year_str}-{args.region}.pkl")
    print(f"Saving data to '{out_file}'")
    persist_data_object(aws_data, out_file, overwrite=True)
Example #3
def main(args):
    cfile = os.path.join(args.target_dir, "config.yaml")
    if args.config is not None:
        # Already validated by parser
        cfile = args.config
    else:
        try:
            utils.validate_file_path(cfile)
        except IOError:
            cfile = os.path.join(args.target_dir, "config")
            utils.validate_file_path(cfile)
    config = load_config(cfile)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transform = REGION_TO_TRANS[config.region]
    lon, lat = [transform(i) for i in eg.v1_get_full_grid_lonlat(eg.ML)]
    model_path = os.path.join(args.target_dir, "model.pt")
    utils.validate_file_path(model_path)
    model_class = UNet
    if args.legacy_model:
        print("Using legacy model")
        model_class = UNetLegacy
    model = model_class(
        in_chan=config.in_chan,
        n_classes=config.n_classes,
        depth=config.depth,
        base_filter_bank_size=config.base_filters,
        skip=config.skips,
        bndry_dropout=config.bndry_dropout,
        bndry_dropout_p=config.bndry_dropout_p,
    )
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model = model.to(device)
    # map_location lets CPU-only hosts load checkpoints saved on GPU
    model_dict = torch.load(model_path, map_location=device)
    model_load(model, model_dict)
    model.eval()
    input_ds = build_input_dataset_form_config(config, False, lat)
    batch_size = args.batch_size if args.batch_size > 0 else config.batch_size
    input_dl = torch.utils.data.DataLoader(input_ds,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           drop_last=False)
    preds, probs = get_predictions(
        input_dl,
        model,
        transform(np.load("../data/masks/ft_esdr_water_mask.npy")),
        LABEL_OTHER,
        device,
        config,
    )
    # TODO
    pred_path = os.path.join(args.target_dir, "pred.npy")
    prob_path = os.path.join(args.target_dir, "prob.npy")
    print(f"Saving probabilities: '{prob_path}'")
    np.save(prob_path, probs)
    if args.save_predictions:
        print(f"Saving predictions: '{pred_path}'")
        np.save(pred_path, preds)
Example #4
def load_data(files, proj):
    if not files:
        raise ValueError("files must not be empty")
    dates = [_parse_date_from_fname(f) for f in files]
    grids = [tbmod.load_tb_file(f, proj) for f in files]
    dates, grids, missing_mask = _fill_gaps(dates, grids)
    lon, lat = eg.v1_get_full_grid_lonlat(proj)
    x, y = eg.v1_lonlat_to_meters(lon, lat, proj)
    # Reduce the 2D coordinate grids to 1D axes: lon/x vary along
    # columns, lat/y vary along rows.
    lon = lon[0]
    lat = lat[:, 0]
    x = x[0]
    y = y[:, 0]
    return dates, lon, lat, x, y, grids, missing_mask
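
A minimal sketch of calling load_data, assuming a directory of Tb files and the eg.ML projection used throughout; the glob pattern is a placeholder.

# Hypothetical usage; the file pattern is a placeholder.
import glob

files = sorted(glob.glob("../data/tb/*"))
dates, lon, lat, x, y, grids, missing_mask = load_data(files, eg.ML)
print(f"{len(dates)} days on a {len(lat)} x {len(lon)} grid")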
Example #5
def main(args):
    with rio.open(args.meta_tif) as tif:
        proj = SNOW_4K_PROJ
        year = int(os.path.basename(os.path.dirname(args.in_dir)))
        ndays = 365 + int(calendar.isleap(year))
        files = sorted(glob.glob(os.path.join(args.in_dir, "*.asc.gz")))
        idxs = [_parse_index(f) for f in files]
        lon, lat = [
            NH_VIEW_TRANS(i) for i in eg.v1_get_full_grid_lonlat(eg.ML)
        ]
        data = np.full((ndays, *lon.shape), SNOW_MISSING_VALUE, dtype=np.int8)
        fill_data_array(data, files, idxs, tif, proj, lon, lat)
    print(f"Saving data to disk: '{args.out_file}'")
    np.save(args.out_file, data)
Example #6
def perform_validation_on_ft_esdr(db, fpaths, mask=None):
    for f in fpaths:
        validate_file_path(f)
    pf = WMOValidationPointFetcher(db)
    pg = PointsGridder(*eg.v1_get_full_grid_lonlat(eg.ML), invalid_mask=mask)
    data = load_ft_esdr_data_from_files(fpaths)
    dates_am = [d.dt for d in data if d.am_grid is not None]
    dates_pm = [d.dt for d in data if d.pm_grid is not None]
    grids_am = [d.am_grid for d in data if d.am_grid is not None]
    grids_pm = [d.pm_grid for d in data if d.pm_grid is not None]

    grids_am = dict(zip(dates_am, grids_am))
    grids_pm = dict(zip(dates_pm, grids_pm))
    perform_default_am_pm_validation(grids_am, grids_pm, pf, pg)
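
As with Example #1, a hedged invocation of perform_validation_on_ft_esdr; the FT-ESDR file pattern is a placeholder.

# Hypothetical usage; file locations are placeholders.
import glob

db = get_db_session("../data/dbs/wmo_gsod.db")
perform_validation_on_ft_esdr(db, sorted(glob.glob("../data/ft_esdr/*.nc")))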
Example #7
def main(args):
    data = np.load(args.in_file).astype(args.type)
    if data.ndim > 3 or data.ndim < 2:
        raise UnsupportedNumberDimsError(
            f"Number of dims must be 2 or 3. Got {data.ndim}."
        )
    elif data.ndim == 3:
        nbands = data.shape[0]
    else:
        nbands = 1
        # Add extra axis for iteration purposes
        data = np.expand_dims(data, axis=0)
    trans = REGION_TO_TRANS[args.region]
    if args.mask is not None:
        mask = trans(np.load(args.mask).astype(bool))
        data[..., mask] = args.missing_value
    crs = eg.GRID_NAME_TO_V1_PROJ[eg.ML]
    x, y = [
        trans(xi)
        for xi in eg.v1_lonlat_to_meters(*eg.v1_get_full_grid_lonlat(eg.ML))
    ]
    x = x[0]
    y = y[:, 0]
    xres = (x[-1] - x[0]) / len(x)
    yres = (y[-1] - y[0]) / len(y)
    t = Affine.translation(x[0], y[0]) * Affine.scale(xres, yres)
    with rio.open(
        args.out_file,
        "w",
        driver="GTiff",
        height=data.shape[1],
        width=data.shape[2],
        count=nbands,
        dtype=args.type,
        crs=crs.srs,
        transform=t,
        compress="lzw",
        nodata=args.missing_value,
    ) as ds:
        for i, band in enumerate(data):
            # rasterio band indices are 1-based
            ds.write(band, i + 1)
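
For reference, the Affine.translation(...) * Affine.scale(...) composition above maps pixel (col, row) indices to map coordinates; a self-contained sketch with made-up numbers:

from affine import Affine

# With origin (-100.0, 50.0) and 0.25-unit pixels (y decreasing),
# pixel (col=2, row=1) lands at (-100.0 + 2 * 0.25, 50.0 + 1 * -0.25).
t = Affine.translation(-100.0, 50.0) * Affine.scale(0.25, -0.25)
assert t * (2, 1) == (-99.5, 49.75)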
Example #8
def worker(args):
    (
        pred_path,
        am_pm,
        region,
        mask_code,
        out_path,
    ) = args
    db = get_db_session("../data/dbs/wmo_gsod.db")
    dates = date_range(dt.date(1988, 1, 2), dt.date(2019, 1, 1))
    trans = REGION_TO_TRANS[region]
    lon, lat = [trans(x) for x in eg.v1_get_full_grid_lonlat(eg.ML)]
    land = trans(np.load("../data/masks/ft_esdr_land_mask.npy"))
    water = ~land
    non_cc_mask = trans(
        np.load("../data/masks/ft_esdr_non_cold_constrained_mask.npy"))
    invalid = non_cc_mask | water
    cc_mask = ~invalid
    inv_cc_mask = land & ~cc_mask
    if mask_code == CC_MASK:
        mask = cc_mask
    elif mask_code == LAND_MASK:
        mask = land
    else:
        mask = inv_cc_mask

    pred = trans(np.load(pred_path))
    df = validate_against_aws_db(pred,
                                 db,
                                 dates,
                                 lon,
                                 lat,
                                 mask,
                                 am_pm,
                                 progress=False)
    df.to_csv(out_path)
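
worker unpacks a single argument tuple, which suggests it is driven by multiprocessing.Pool; a sketch under that assumption, where the prediction paths, region key, and output paths are placeholders:

# Hypothetical driver; all literals below are placeholders.
from multiprocessing import Pool

jobs = [
    ("../runs/a/pred.npy", "AM", "nh", CC_MASK, "../out/val_a.csv"),
    ("../runs/b/pred.npy", "PM", "nh", LAND_MASK, "../out/val_b.csv"),
]
with Pool(len(jobs)) as pool:
    pool.map(worker, jobs)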
Example #9
def main(args):
    db = get_db_session(args.db_path)
    out_dir = args.out_dir
    os.makedirs(out_dir, exist_ok=True)
    interp_plot_dir = os.path.join(out_dir, "interp")
    os.makedirs(interp_plot_dir, exist_ok=True)
    dist_plot_dir = os.path.join(out_dir, "dist")
    os.makedirs(dist_plot_dir, exist_ok=True)
    start_date = dt.date(2000, 7, 1)
    end_date = dt.date(2001, 7, 1)
    dt_delta = dt.timedelta(1)
    dates, records = load_records(db, start_date, end_date, dt_delta)
    lons, lats = eg.v1_get_full_grid_lonlat(eg.ML)
    x, y = eg.v1_lonlat_to_meters(lons, lats, eg.ML)
    interp_grids = np.zeros((len(records), *x.shape))
    dist_grids = np.zeros_like(interp_grids)
    fill_grids(x, y, records, interp_grids, dist_grids)
    dmin = dist_grids.min()
    dmax = dist_grids.max()

    for i in range(len(dates)):
        plot(
            dates[i],
            lons,
            lats,
            interp_grids[i],
            dist_grids[i],
            interp_plot_dir,
            dist_plot_dir,
            dmin,
            dmax,
        )
    db.close()
Example #10
def main(args):
    lon, lat = eg.v1_get_full_grid_lonlat(eg.ML)
    dates = get_dates_for_year(args.year)
    rads = get_daily_radiation_parallel(dates, lon, lat, args.workers)
    print(f"Saving to: {args.out_file}")
    np.save(args.out_file, rads)
Example #11
TYPE_PM = "PM"
# Composite
TYPE_CO = "CO"

OTHER = -1
FROZEN = 0
THAWED = 1

FT_ESDR_FROZEN = 0
FT_ESDR_THAWED = 1
# Frozen in AM, thawed in PM
FT_ESDR_TRANSITIONAL = 2
# Thawed in AM, frozen in PM
FT_ESDR_INV_TRANSITIONAL = 3

_EASE_LON, _EASE_LAT = eg.v1_get_full_grid_lonlat(eg.ML)
_EPOINTS = np.array(list(zip(_EASE_LON.ravel(), _EASE_LAT.ravel())))
_EASE_NH_MASK = _EASE_LAT >= 0.0
_EASE_SH_MASK = _EASE_LAT < 0.0


def ft_model_zero_threshold(temps):
    # Classify freeze/thaw by thresholding at 0 C (273.15 K)
    return (temps > 273.15).astype("uint8")


def get_empty_data_grid(shape):
    return np.full(shape, OTHER, dtype="int8")


def get_empty_data_grid_like(a):
    return get_empty_data_grid(a.shape)
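
A small sketch of Example #11's helpers; the temperature values are made up:

import numpy as np

temps = np.array([270.0, 273.15, 280.5])     # Kelvin
ft = ft_model_zero_threshold(temps)          # -> array([0, 0, 1], dtype=uint8)
empty = get_empty_data_grid_like(_EASE_LAT)  # filled with OTHER (-1), int8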
Example #12
def prep_data(
    region,
    start_date,
    end_date,
    am_pm,
    dest="../data/cleaned",
    db_path="../data/dbs/wmo_gsod.db",
    prep_era_t2m=False,
):
    if isinstance(start_date, int):
        start_date = _parse_date_arg(str(start_date))
    if isinstance(end_date, int):
        end_date = _parse_date_arg(str(end_date), False)
    assert start_date < end_date, "Start date must come before end date"

    transform = REGION_TO_TRANS[region]
    out_lon, out_lat = [
        transform(i) for i in eg.v1_get_full_grid_lonlat(eg.ML)
    ]
    base_water_mask = np.load("../data/masks/ft_esdr_water_mask.npy")
    water_mask = transform(base_water_mask)
    land_mask = ~water_mask

    data = {}
    # Gap filled Tb
    tbdir = f"../data/tb/gapfilled_{region}_{am_pm.lower()}"
    paths = [
        f"{tbdir}/tb_{y}_{am_pm}_{region}_filled.npy"
        for y in range(start_date.year, end_date.year + 1)
    ]
    print("Loading gap-filled tb")
    tb = torch.utils.data.ConcatDataset([dh.NpyDataset(p) for p in paths])
    tb = dh.dataset_to_array(trim_datasets_to_dates(tb, start_date, end_date))
    data[TB_KEY] = tb
    # ERA5 FT
    # TODO: caching
    print("Loading ERA")
    era_ft = dh.dataset_to_array(
        trim_datasets_to_dates(
            dh.TransformPipelineDataset(
                dh.ERA5BidailyDataset(
                    [
                        f"../data/era5/t2m/bidaily/era5-t2m-bidaily-{y}.nc"
                        for y in range(start_date.year, end_date.year + 1)
                    ],
                    "t2m",
                    am_pm,
                    out_lon,
                    out_lat,
                ),
                [dh.FTTransform()],
            ),
            start_date,
            end_date,
        ))
    data[ERA_FT_KEY] = era_ft
    # ERA5 t2m
    # TODO: caching
    if prep_era_t2m:
        era_t2m = dh.dataset_to_array(
            trim_datasets_to_dates(
                dh.ERA5BidailyDataset(
                    [
                        f"../data/era5/t2m/bidaily/era5-t2m-bidaily-{y}.nc"
                        for y in range(start_date.year, end_date.year + 1)
                    ],
                    "t2m",
                    am_pm,
                    out_lon,
                    out_lat,
                ),
                start_date,
                end_date,
            ))
        data[ERA_T2M_KEY] = era_t2m
    sizes = set(len(d) for d in data.values())
    assert (len(sizes) == 1
            ), "All data must be the same length in the time dimension"
    prep(
        start_date,
        end_date,
        data,
        dest,
        region,
        land_mask,
        out_lon,
        out_lat,
        am_pm,
        db_path,
    )
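
A hedged call sketch for prep_data; the region key and year bounds are placeholders since REGION_TO_TRANS is not shown here.

# Hypothetical invocation; "nh" and the years are placeholders.
prep_data(
    region="nh",
    start_date=2015,  # ints are converted via _parse_date_arg
    end_date=2016,
    am_pm="AM",
)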
Example #13
    get_db_session,
)
import ease_grid as eg

date = dt.date(2000, 1, 1)
db = get_db_session("../data/wmo_w_indexing.db")
sites = [(r.met_station.lon, r.met_station.lat, r.temperature_mean)
         for r in db.query(DbWMOMetDailyTempRecord).filter(
             DbWMOMetDailyTempRecord.date_int == date_to_int(date)).all()]
points = np.array([r[:-1] for r in sites])
px = points[:, 0]
py = points[:, 1]
pxm, pym = eg.v1_lonlat_to_meters(points[:, 0], points[:, 1])
pm = np.array(list(zip(pxm, pym)))
values = np.array([int(s[-1] > 273.15) for s in sites])
lons, lats = eg.v1_get_full_grid_lonlat(eg.ML)
xm, ym = eg.v1_lonlat_to_meters(lons, lats, eg.ML)
ip = NearestNDInterpolator(pm, values)
igrid = ip(xm, ym)
dist, _ = ip.tree.query(np.array(list(zip(xm.ravel(), ym.ravel()))))
dist = dist.reshape(xm.shape)

cmap = colors.ListedColormap(["skyblue", "lightcoral"])
norm = colors.BoundaryNorm([0, 1, 2], 2)

# Plot interpolation
ax = plt.axes(projection=ccrs.EckertVI())
plt.contourf(lons,
             lats,
             igrid,
             2,
Example #14
    DbWMOMetDailyTempRecord,
    DbWMOMetStation,
)
from transforms import N45_VIEW_TRANS as transform


pred = np.load("../runs/gscc/2020-11-02-19:51:31.769613-gold/pred.npy")
land_mask = ~transform(np.load("../data/masks/ft_esdr_water_mask.npy"))

dates, _, _ = dh.read_accuracies_file(
    "../runs/gscc/2020-11-02-19:51:31.769613-gold/acc.csv"
)
d2i = {d: i for i, d in enumerate(dates)}
assert len(dates) == len(pred)
d2p = {d: p.ravel() for d, p in zip(dates, pred)}
lon, lat = [transform(x) for x in eg.v1_get_full_grid_lonlat(eg.ML)]

df = dh.get_aws_full_data_for_dates(
    dates,
    "../data/dbs/wmo_gsod.db",
    land_mask,
    lon,
    lat,
    RETRIEVAL_MIN,
)
df = df[df.date != dt.date(2015, 1, 1)]
vres = {
    sid: np.zeros(len(dates), dtype=int) - 1 for sid in sorted(df.sid.unique())
}
ftres = {
    sid: np.zeros(len(dates), dtype=int) - 1 for sid in sorted(df.sid.unique())