def perform_custom_validation(db, fpaths, dates, start_row, start_col, comp_type, mask=None):
    """Validate a stack of gridded .npy files against WMO station data.

    The grids are treated as a window of the full EASE v1 ML grid whose
    upper-left corner sits at (start_row, start_col). Validation results
    are grouped by month and printed.
    """
    validate_file_path_list(fpaths)
    if len(dates) != len(fpaths):
        raise RuntimeError("Input paths size does not match input dates size")
    if len(fpaths) == 0:
        print("No data")
        return
    fetcher = WMOValidationPointFetcher(db, retrieval_type=comp_type)
    grids = load_npy_files(fpaths)
    _verify_grids_are_homogenous_shape(grids)
    nrows, ncols = grids[0].shape
    # Index arrays selecting the sub-window out of the full grid
    col_idx, row_idx = np.meshgrid(
        range(start_col, start_col + ncols),
        range(start_row, start_row + nrows),
    )
    if mask is not None:
        mask = mask[row_idx, col_idx]
    lon, lat = eg.v1_get_full_grid_lonlat(eg.ML)
    lon = lon[row_idx, col_idx]
    lat = lat[row_idx, col_idx]
    bounds = [lon.min(), lon.max(), lat.min(), lat.max()]
    gridder = PointsGridder(lon, lat, invalid_mask=mask)
    date_to_grid = dict(zip(dates, grids))
    results = perform_bounded_validation(date_to_grid, fetcher, gridder, bounds)
    output_validation_stats_grouped_by_month(results, [COL_DATE, COL_SCORE])
def main(args):
    """Build AWS (weather-station) validation data for a year range and save it.

    Grid/land-mask inputs come either from a named region transform or
    directly from the parsed arguments.
    """
    if args.start_year > args.end_year:
        raise ValueError("Start year must be less than or equal to end_year")
    year_str = get_year_str(args.start_year, args.end_year)
    if args.region is not None:
        transform = REGION_TO_TRANS[args.region]
        land_mask = ~transform(np.load("../data/masks/ft_esdr_water_mask.npy"))
        lon, lat = [transform(i) for i in eg.v1_get_full_grid_lonlat(eg.ML)]
    else:
        assert args.lon is not None and args.lat is not None
        land_mask = ~args.water
        lon, lat = args.lon, args.lat
    # TODO: add option for AM/PM
    ret_type = RETRIEVAL_MIN
    dates_path = args.dates or f"../data/cleaned/date_map-{year_str}-{args.region}.csv"
    dates = load_dates(dates_path)
    # TODO: add option for AM/PM
    db_path = args.db_path or "../data/dbs/wmo_gsod.db"
    aws_data = get_aws_data(dates, db_path, land_mask, lon, lat, ret_type)
    # TODO: add option for AM/PM
    out_file = os.path.join(args.out_dir, f"aws_data-AM-{year_str}-{args.region}.pkl")
    print(f"Saving data to '{out_file}'")
    persist_data_object(aws_data, out_file, overwrite=True)
def main(args):
    """Run model inference for a training-run directory and save the outputs.

    Loads the run config and trained model from args.target_dir, builds the
    input dataset, computes predictions/probabilities, and writes prob.npy
    (and optionally pred.npy) back into the run directory.
    """
    cfile = os.path.join(args.target_dir, "config.yaml")
    if args.config is not None:
        # Already validated by parser
        cfile = args.config
    else:
        try:
            utils.validate_file_path(cfile)
        except IOError:
            # Fall back to the legacy extensionless config file name
            cfile = os.path.join(args.target_dir, "config")
            utils.validate_file_path(cfile)
    config = load_config(cfile)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transform = REGION_TO_TRANS[config.region]
    lon, lat = [transform(i) for i in eg.v1_get_full_grid_lonlat(eg.ML)]
    model_path = os.path.join(args.target_dir, "model.pt")
    utils.validate_file_path(model_path)
    model_class = UNet
    if args.legacy_model:
        print("Using legacy model")
        model_class = UNetLegacy
    model = model_class(
        in_chan=config.in_chan,
        n_classes=config.n_classes,
        depth=config.depth,
        base_filter_bank_size=config.base_filters,
        skip=config.skips,
        bndry_dropout=config.bndry_dropout,
        bndry_dropout_p=config.bndry_dropout_p,
    )
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model = model.to(device)
    model_dict = torch.load(model_path)
    model_load(model, model_dict)
    model.eval()
    input_ds = build_input_dataset_form_config(config, False, lat)
    # Command-line batch size (if positive) overrides the config's
    batch_size = args.batch_size if args.batch_size > 0 else config.batch_size
    input_dl = torch.utils.data.DataLoader(
        input_ds, batch_size=batch_size, shuffle=False, drop_last=False
    )
    preds, probs = get_predictions(
        input_dl,
        model,
        transform(np.load("../data/masks/ft_esdr_water_mask.npy")),
        LABEL_OTHER,
        device,
        config,
    )
    # TODO
    pred_path = os.path.join(args.target_dir, "pred.npy")
    prob_path = os.path.join(args.target_dir, "prob.npy")
    print(f"Saving probabilities: '{prob_path}'")
    np.save(prob_path, probs)
    if args.save_predictions:
        print(f"Saving predictions: '{pred_path}'")
        # BUG FIX: arguments were swapped (np.save(preds, pred_path));
        # np.save expects (file, arr), so predictions were never saved
        # to the intended path.
        np.save(pred_path, preds)
def load_data(files, proj):
    """Load TB grid files and derive 1D grid coordinate vectors.

    Returns (dates, lon, lat, x, y, grids, missing_mask) where lon/x run
    along columns and lat/y along rows of the regular grid.
    """
    if not files:
        raise ValueError("files must not be empty")
    dates = [_parse_date_from_fname(f) for f in files]
    grids = [tbmod.load_tb_file(f, proj) for f in files]
    dates, grids, missing_mask = _fill_gaps(dates, grids)
    full_lon, full_lat = eg.v1_get_full_grid_lonlat(proj)
    full_x, full_y = eg.v1_lonlat_to_meters(full_lon, full_lat, proj)
    # The grid is regular, so one row/column captures each axis
    return (
        dates,
        full_lon[0],
        full_lat[:, 0],
        full_x[0],
        full_y[:, 0],
        grids,
        missing_mask,
    )
def main(args):
    """Pack a year's worth of daily snow .asc.gz grids into one .npy array."""
    with rio.open(args.meta_tif) as tif:
        proj = SNOW_4K_PROJ
        # The parent directory name encodes the year of the data
        year = int(os.path.basename(os.path.dirname(args.in_dir)))
        ndays = 366 if calendar.isleap(year) else 365
        files = sorted(glob.glob(os.path.join(args.in_dir, "*.asc.gz")))
        idxs = [_parse_index(f) for f in files]
        lon, lat = [NH_VIEW_TRANS(i) for i in eg.v1_get_full_grid_lonlat(eg.ML)]
        # One day per leading slice; unfilled days keep the missing sentinel
        data = np.full((ndays, *lon.shape), SNOW_MISSING_VALUE, dtype=np.int8)
        fill_data_array(data, files, idxs, tif, proj, lon, lat)
        print(f"Saving data to disk: '{args.out_file}'")
        np.save(args.out_file, data)
def perform_validation_on_ft_esdr(db, fpaths, mask=None):
    """Validate FT-ESDR AM/PM grids against WMO station data."""
    for path in fpaths:
        validate_file_path(path)
    fetcher = WMOValidationPointFetcher(db)
    gridder = PointsGridder(*eg.v1_get_full_grid_lonlat(eg.ML), invalid_mask=mask)
    records = load_ft_esdr_data_from_files(fpaths)
    # Map date -> grid, skipping records missing that half of the day
    grids_am = {r.dt: r.am_grid for r in records if r.am_grid is not None}
    grids_pm = {r.dt: r.pm_grid for r in records if r.pm_grid is not None}
    perform_default_am_pm_validation(grids_am, grids_pm, fetcher, gridder)
def main(args):
    """Convert a 2D/3D .npy grid to a GeoTIFF in the EASE v1 ML projection.

    Raises UnsupportedNumberDimsError for any other dimensionality. An
    optional boolean mask file marks cells set to the missing value.
    """
    data = np.load(args.in_file).astype(args.type)
    if len(data.shape) > 3 or len(data.shape) < 2:
        raise UnsupportedNumberDimsError(
            f"Number of dims must be 2 or 3. Got {len(data.shape)}."
        )
    elif len(data.shape) == 3:
        nbands = data.shape[0]
    else:
        nbands = 1
        # Add extra axis for iteration purposes
        data = np.expand_dims(data, axis=0)
    trans = REGION_TO_TRANS[args.region]
    if args.mask is not None:
        mask = trans(np.load(args.mask).astype(bool))
        data[..., mask] = args.missing_value
    crs = eg.GRID_NAME_TO_V1_PROJ[eg.ML]
    x, y = [
        trans(xi)
        for xi in eg.v1_lonlat_to_meters(*eg.v1_get_full_grid_lonlat(eg.ML))
    ]
    x = x[0]
    y = y[:, 0]
    # NOTE(review): x/y appear to be cell centers, so x[-1] - x[0] spans
    # len(x) - 1 cells; dividing by len(x) slightly underestimates the
    # resolution. Left unchanged to preserve existing output -- confirm.
    xres = (x[-1] - x[0]) / len(x)
    yres = (y[-1] - y[0]) / len(y)
    t = Affine.translation(x[0], y[0]) * Affine.scale(xres, yres)
    # Context manager ensures the dataset is closed even if a write fails
    # (the original leaked the handle on error via a manual close()).
    with rio.open(
        args.out_file,
        "w",
        driver="GTiff",
        height=data.shape[1],
        width=data.shape[2],
        count=nbands,
        dtype=args.type,
        crs=crs.srs,
        transform=t,
        compress="lzw",
        nodata=args.missing_value,
    ) as ds:
        for i, band in enumerate(data):
            ds.write(band, i + 1)
def worker(args):
    """Validate one prediction file against the AWS DB and write a CSV.

    `args` is a (pred_path, am_pm, region, mask_code, out_path) tuple so
    the function can be used with multiprocessing map-style APIs.
    """
    pred_path, am_pm, region, mask_code, out_path = args
    db = get_db_session("../data/dbs/wmo_gsod.db")
    dates = date_range(dt.date(1988, 1, 2), dt.date(2019, 1, 1))
    trans = REGION_TO_TRANS[region]
    lon, lat = [trans(x) for x in eg.v1_get_full_grid_lonlat(eg.ML)]
    land = trans(np.load("../data/masks/ft_esdr_land_mask.npy"))
    water = ~land
    non_cc_mask = trans(
        np.load("../data/masks/ft_esdr_non_cold_constrained_mask.npy"))
    invalid = non_cc_mask | water
    cc_mask = ~invalid
    inv_cc_mask = land & ~cc_mask
    # Select the validation mask requested by the caller
    if mask_code == CC_MASK:
        mask = cc_mask
    elif mask_code == LAND_MASK:
        mask = land
    else:
        mask = inv_cc_mask
    pred = trans(np.load(pred_path))
    df = validate_against_aws_db(
        pred, db, dates, lon, lat, mask, am_pm, progress=False
    )
    df.to_csv(out_path)
def main(args):
    """Interpolate a year of station temperatures onto the EASE grid and plot.

    Writes interpolation plots under <out_dir>/interp and nearest-station
    distance plots under <out_dir>/dist.
    """
    db = get_db_session(args.db_path)
    out_dir = args.out_dir
    # exist_ok=True replaces the racy isdir-then-makedirs pattern
    os.makedirs(out_dir, exist_ok=True)
    interp_plot_dir = os.path.join(out_dir, "interp")
    os.makedirs(interp_plot_dir, exist_ok=True)
    dist_plot_dir = os.path.join(out_dir, "dist")
    os.makedirs(dist_plot_dir, exist_ok=True)
    start_date = dt.date(2000, 7, 1)
    end_date = dt.date(2001, 7, 1)
    dt_delta = dt.timedelta(1)
    dates, records = load_records(db, start_date, end_date, dt_delta)
    lons, lats = eg.v1_get_full_grid_lonlat(eg.ML)
    x, y = eg.v1_lonlat_to_meters(lons, lats, eg.ML)
    interp_grids = np.zeros((len(records), *x.shape))
    dist_grids = np.zeros_like(interp_grids)
    fill_grids(x, y, records, interp_grids, dist_grids)
    # Shared color scale across all distance plots
    dmin = dist_grids.min()
    dmax = dist_grids.max()
    try:
        # zip iteration replaces indexing by range(len(dates))
        for date, igrid, dgrid in zip(dates, interp_grids, dist_grids):
            plot(
                date,
                lons,
                lats,
                igrid,
                dgrid,
                interp_plot_dir,
                dist_plot_dir,
                dmin,
                dmax,
            )
    finally:
        # Close the session even if a plot call raises
        db.close()
def main(args):
    """Compute daily radiation grids for one year and save them to disk."""
    lon, lat = eg.v1_get_full_grid_lonlat(eg.ML)
    year_dates = get_dates_for_year(args.year)
    radiation = get_daily_radiation_parallel(year_dates, lon, lat, args.workers)
    print(f"Saving to: {args.out_file}")
    np.save(args.out_file, radiation)
TYPE_PM = "PM"
# Composite
TYPE_CO = "CO"

# Classification labels used throughout this module
OTHER = -1
FROZEN = 0
THAWED = 1

# FT-ESDR product label values
FT_ESDR_FROZEN = 0
FT_ESDR_THAWED = 1
# Frozen in AM, thawed in PM
FT_ESDR_TRANSITIONAL = 2
# Thawed in AM, frozen in PM
FT_ESDR_INV_TRANSITIONAL = 3

# Full EASE v1 ML grid coordinates and hemisphere masks, computed once
_EASE_LON, _EASE_LAT = eg.v1_get_full_grid_lonlat(eg.ML)
_EPOINTS = np.array(list(zip(_EASE_LON.ravel(), _EASE_LAT.ravel())))
_EASE_NH_MASK = _EASE_LAT >= 0.0
_EASE_SH_MASK = _EASE_LAT < 0.0


def ft_model_zero_threshold(temps):
    """Classify temperatures (K) as frozen (0) or thawed (1) at 273.15 K."""
    return (temps > 273.15).astype("uint8")


def get_empty_data_grid(shape):
    """Return an int8 grid of the given shape filled with the OTHER label."""
    return np.full(shape, OTHER, dtype="int8")


def get_empty_data_grid_like(a):
    """Return an OTHER-filled grid matching the shape of ``a``."""
    return get_empty_data_grid(a.shape)
def prep_data(
    region,
    start_date,
    end_date,
    am_pm,
    dest="../data/cleaned",
    db_path="../data/dbs/wmo_gsod.db",
    prep_era_t2m=False,
):
    """Assemble Tb and ERA5 datasets for a region/date range and run prep().

    Integer start/end dates are parsed into date objects first. All loaded
    datasets must span the same number of time steps.
    """
    if isinstance(start_date, int):
        start_date = _parse_date_arg(str(start_date))
    if isinstance(end_date, int):
        end_date = _parse_date_arg(str(end_date), False)
    assert start_date < end_date, "Start date must come before end date"
    transform = REGION_TO_TRANS[region]
    out_lon, out_lat = [transform(i) for i in eg.v1_get_full_grid_lonlat(eg.ML)]
    water_mask = transform(np.load("../data/masks/ft_esdr_water_mask.npy"))
    land_mask = ~water_mask
    # Year range and ERA file list are shared by the ERA loads below
    years = range(start_date.year, end_date.year + 1)
    era_files = [
        f"../data/era5/t2m/bidaily/era5-t2m-bidaily-{y}.nc" for y in years
    ]
    data = {}
    # Gap filled Tb
    tbdir = f"../data/tb/gapfilled_{region}_{am_pm.lower()}"
    tb_paths = [f"{tbdir}/tb_{y}_{am_pm}_{region}_filled.npy" for y in years]
    print("Loading gap-filled tb")
    tb_ds = torch.utils.data.ConcatDataset([dh.NpyDataset(p) for p in tb_paths])
    data[TB_KEY] = dh.dataset_to_array(
        trim_datasets_to_dates(tb_ds, start_date, end_date)
    )
    # ERA5 freeze/thaw derived from t2m
    # TODO: caching
    print("Loading ERA")
    era_ft_ds = dh.TransformPipelineDataset(
        dh.ERA5BidailyDataset(era_files, "t2m", am_pm, out_lon, out_lat),
        [dh.FTTransform()],
    )
    data[ERA_FT_KEY] = dh.dataset_to_array(
        trim_datasets_to_dates(era_ft_ds, start_date, end_date)
    )
    # Raw ERA5 t2m, only when requested
    # TODO: caching
    if prep_era_t2m:
        era_t2m_ds = dh.ERA5BidailyDataset(
            era_files, "t2m", am_pm, out_lon, out_lat
        )
        data[ERA_T2M_KEY] = dh.dataset_to_array(
            trim_datasets_to_dates(era_t2m_ds, start_date, end_date)
        )
    lengths = set(len(d) for d in data.values())
    assert (len(lengths) == 1
            ), "All data must be the same length in the time dimension"
    prep(
        start_date,
        end_date,
        data,
        dest,
        region,
        land_mask,
        out_lon,
        out_lat,
        am_pm,
        db_path,
    )
get_db_session, ) import ease_grid as eg date = dt.date(2000, 1, 1) db = get_db_session("../data/wmo_w_indexing.db") sites = [(r.met_station.lon, r.met_station.lat, r.temperature_mean) for r in db.query(DbWMOMetDailyTempRecord).filter( DbWMOMetDailyTempRecord.date_int == date_to_int(date)).all()] points = np.array([r[:-1] for r in sites]) px = points[:, 0] py = points[:, 1] pxm, pym = eg.v1_lonlat_to_meters(points[:, 0], points[:, 1]) pm = np.array(list(zip(pxm, pym))) values = np.array([int(s[-1] > 273.15) for s in sites]) lons, lats = eg.v1_get_full_grid_lonlat(eg.ML) xm, ym = eg.v1_lonlat_to_meters(lons, lats, eg.ML) ip = NearestNDInterpolator(pm, values) igrid = ip(xm, ym) dist, _ = ip.tree.query(np.array(list(zip(xm.ravel(), ym.ravel())))) dist = dist.reshape(xm.shape) cmap = cmap = colors.ListedColormap(["skyblue", "lightcoral"]) norm = colors.BoundaryNorm([0, 1, 2], 2) # Plot interpolation ax = plt.axes(projection=ccrs.EckertVI()) plt.contourf(lons, lats, igrid, 2,
DbWMOMetDailyTempRecord, DbWMOMetStation, ) from transforms import N45_VIEW_TRANS as transform pred = np.load("../runs/gscc/2020-11-02-19:51:31.769613-gold/pred.npy") land_mask = ~transform(np.load("../data/masks/ft_esdr_water_mask.npy")) dates, _, _ = dh.read_accuracies_file( "../runs/gscc/2020-11-02-19:51:31.769613-gold/acc.csv" ) d2i = {d: i for i, d in enumerate(dates)} assert len(dates) == len(pred) d2p = {d: p.ravel() for d, p in zip(dates, pred)} lon, lat = [transform(x) for x in eg.v1_get_full_grid_lonlat(eg.ML)] df = dh.get_aws_full_data_for_dates( dates, "../data/dbs/wmo_gsod.db", land_mask, lon, lat, RETRIEVAL_MIN, ) df = df[df.date != dt.date(2015, 1, 1)] vres = { sid: np.zeros(len(dates), dtype=int) - 1 for sid in sorted(df.sid.unique()) } ftres = { sid: np.zeros(len(dates), dtype=int) - 1 for sid in sorted(df.sid.unique())