# Module-level imports inferred from usage in the functions below. Helper
# functions such as print_timestamp, time_set_mid, drop_var_names,
# copy_var_names, _varname_resolved, get_weight, get_rmask, clean_units,
# conv_units, copy_fill_settings, repl_coord, xr_ds_ex, and
# gen_time_bounds_values, as well as the module-level names var_specs_fname,
# time_name, nyrs, and var_const, are assumed to be defined elsewhere in this repo.
import math
import sys
import warnings
from datetime import datetime, timezone

import cf_units
import cftime
import numpy as np
import xarray as xr
import yaml

import dask.distributed
from esmlab import anomaly, resample

try:
    # guarded import: _tseries_gen falls back based on sys.modules membership
    import ncar_jobqueue
except ImportError:
    pass


def _latlon_sel_gen(varname, isel_dict, component, ensemble, entries):
    """
    generate values for a particular ensemble member, return a Dataset object
    """
    print_timestamp(f"varname={varname}")
    varname_resolved = _varname_resolved(varname, component)
    fnames = entries.loc[entries["ensemble"] == ensemble].files.tolist()
    print(fnames)

    with open(var_specs_fname, mode="r") as fptr:
        var_specs_all = yaml.safe_load(fptr)

    # var_spec is looked up for parity with _tseries_gen;
    # it is currently unused in this function
    if varname in var_specs_all[component]["vars"]:
        var_spec = var_specs_all[component]["vars"][varname]
    else:
        var_spec = {}

    ds_out_list = []

    with xr.open_dataset(fnames[0]) as ds0:
        drop_var_names_loc = drop_var_names(component, ds0, varname_resolved)
        var_list = [time_name, varname]
        if "bounds" in ds0[time_name].attrs:
            var_list.append(ds0[time_name].attrs["bounds"])
        var_list.extend(copy_var_names(component))

        for fname in fnames:
            with xr.open_dataset(fname, drop_variables=drop_var_names_loc) as ds_in:
                ds_out = ds_in[var_list].isel(isel_dict)
                ds_out_list.append(ds_out.load())

        ds_out = xr.concat(
            ds_out_list, dim=time_name, coords="minimal", compat="override"
        )

        # restore encoding for time from first file
        ds_out[time_name].encoding = ds0[time_name].encoding

        # set ds_out.time to mid-interval values
        ds_out = time_set_mid(ds_out, time_name)

        # copy file attributes, prepending a history message
        ds_out.attrs = ds0.attrs
        datestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
        msg = f"{datestamp}: created by {__file__}"
        if "history" in ds_out.attrs:
            ds_out.attrs["history"] = "\n".join([msg, ds_out.attrs["history"]])
        else:
            ds_out.attrs["history"] = msg
        ds_out.attrs["input_file_list"] = " ".join(fnames)

    return ds_out
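# Hedged usage sketch for _latlon_sel_gen. The variable name, index values, and
# catalog_entries below are hypothetical; entries is assumed to be a pandas
# DataFrame with "ensemble" and "files" columns, as used above.
def _example_latlon_sel_gen(catalog_entries):
    """illustrative only: extract a single-column time series for ensemble member 0"""
    return _latlon_sel_gen(
        varname="TEMP",  # hypothetical variable name
        isel_dict={"nlat": 100, "nlon": 150},  # grid indices, not degrees
        component="ocn",
        ensemble=0,
        entries=catalog_entries,
    )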
def test_time_set_mid(decode_times, deep, apply_chunk):
    ds = xr_ds_ex(decode_times, nyrs=nyrs, var_const=var_const, time_mid=False)
    if apply_chunk:
        ds = ds.chunk({"time": 12})

    mid_month_values = gen_time_bounds_values(nyrs).mean(axis=1)
    if decode_times:
        time_encoding = ds["time"].encoding
        expected_values = cftime.num2date(
            mid_month_values, time_encoding["units"], time_encoding["calendar"]
        )
    else:
        expected_values = mid_month_values

    ds_out = time_set_mid(ds, "time", deep)

    assert ds_out.attrs == ds.attrs
    assert ds_out.encoding == ds.encoding
    assert ds_out.chunks == ds.chunks

    for varname in ds.variables:
        assert ds_out[varname].attrs == ds[varname].attrs
        assert ds_out[varname].encoding == ds[varname].encoding
        assert ds_out[varname].chunks == ds[varname].chunks
        if varname == "time":
            assert np.all(ds_out[varname].values == expected_values)
        else:
            assert np.all(ds_out[varname].values == ds[varname].values)
            assert (ds_out[varname].data is ds[varname].data) == (not deep)

    # verify that values are independent of ds being chunked in time
    ds_chunk = xr_ds_ex(
        decode_times, nyrs=nyrs, var_const=var_const, time_mid=False
    ).chunk({"time": 6})
    ds_chunk_out = time_set_mid(ds_chunk, "time")
    assert ds_chunk_out.identical(ds_out)
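# A minimal, self-contained sketch of the behavior test_time_set_mid exercises:
# time_set_mid is expected to replace the time coordinate with the midpoints of
# its bounds variable. Names and values here are illustrative, not the repo's
# implementation; the midpoint computation mirrors the mean(axis=1) used above.
def _example_time_mid_from_bounds():
    """illustrative only: compute mid-interval time values from bounds"""
    import numpy as np
    import xarray as xr

    time_bounds = np.array([[0.0, 31.0], [31.0, 59.0], [59.0, 90.0]])
    ds = xr.Dataset(
        {"time_bounds": (("time", "d2"), time_bounds)},
        coords={"time": ("time", time_bounds[:, 1])},
    )
    ds["time"].attrs["bounds"] = "time_bounds"
    mid = ds["time_bounds"].mean(dim="d2")  # midpoint of each interval
    ds = ds.assign_coords(time=mid.values)
    print(ds["time"].values)  # [15.5 45.  74.5]
    return ds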
def compute_mon_anomaly(ds):
    """esmlab wrapper"""

    # esmlab.anomaly generates a fatal error for particular time values,
    # so adjust them to avoid this
    tb_name = ds.time.bounds
    tb = ds[tb_name]
    time_decoded = tb.dtype == np.dtype("O")
    if time_decoded:
        tb_vals_encoded = cftime.date2num(
            ds[tb_name].values,
            ds.time.encoding["units"],
            ds.time.encoding["calendar"],
        )
    else:
        tb_vals_encoded = ds[tb_name].values
    val0 = -0.5 / 24.0
    tb_vals_encoded = np.where(
        abs(tb_vals_encoded - val0) < 1.0e-10, val0, tb_vals_encoded
    )
    if time_decoded:
        ds[tb_name].values = cftime.num2date(
            tb_vals_encoded,
            ds.time.encoding["units"],
            ds.time.encoding["calendar"],
        )
    else:
        ds[tb_name].values = tb_vals_encoded
    ds = time_set_mid(ds, "time")

    # return esmlab.climatology.compute_mon_anomaly(ds)
    ds_out = anomaly(ds, clim_freq="mon")

    # propagate particular encoding values
    for key in ["unlimited_dims"]:
        if key in ds.encoding:
            ds_out.encoding[key] = ds.encoding[key]

    # copy file attributes, prepending a history message
    for key in ds.attrs:
        if key == "history":
            datestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
            msg = (
                f"{datestamp}: created by esmlab.anomaly, "
                "with modifications from esmlab_wrap"
            )
            ds_out.attrs[key] = "\n".join([msg, ds.attrs[key]])
        else:
            ds_out.attrs[key] = ds.attrs[key]

    return ds_out
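# Standalone illustration of the snapping done in compute_mon_anomaly above:
# encoded time-bounds values within 1e-10 of val0 are set to exactly val0
# before esmlab.anomaly is called. The input values here are illustrative.
def _example_snap_time_bounds():
    """illustrative only: snap near-val0 values to exactly val0"""
    import numpy as np

    val0 = -0.5 / 24.0
    vals = np.array([val0 + 5.0e-11, 0.0, 1.0])
    # only the first value is within tolerance, so only it gets snapped
    print(np.where(np.abs(vals - val0) < 1.0e-10, val0, vals))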
def _gen_ds_var_single_ensemble(varname, component, experiment, stream, df_in, ensemble):
    """
    return xarray.Dataset containing varname for a single ensemble member
    """

    # get DataFrame of matching data_catalog entries
    df = df_in.loc[df_in["ensemble"] == ensemble]
    if df.empty:
        raise ValueError(
            f"no file matches found for varname={varname}, component={component}, "
            f"experiment={experiment}, ensemble={ensemble}"
        )

    paths = df["files"].tolist()

    with xr.open_dataset(paths[0]) as ds0:
        rank = len(ds0[varname].dims)
        time_chunksize = 10 * 12 if rank < 4 else 12
        time_encoding = ds0[time_name].encoding
        ds_encoding = ds0.encoding
        drop_var_names_loc = drop_var_names(component, ds0, varname)

    ds_out = xr.open_mfdataset(
        paths,
        data_vars="minimal",
        coords="minimal",
        compat="override",
        combine="by_coords",
        drop_variables=drop_var_names_loc,
    ).chunk(chunks={time_name: time_chunksize})

    for key in ["units", "calendar"]:
        if key in time_encoding:
            ds_out[time_name].encoding[key] = time_encoding[key]

    # set ds_out.time to mid-interval values
    ds_out = time_set_mid(ds_out, time_name)

    for key in ["unlimited_dims"]:
        if key in ds_encoding:
            ds_out.encoding[key] = ds_encoding[key]

    return ds_out
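# Hedged usage sketch for _gen_ds_var_single_ensemble. All argument values are
# hypothetical; df_in is assumed to be a pandas DataFrame with "ensemble" and
# "files" columns, as used above.
def _example_gen_ds_var_single_ensemble(df_in):
    """illustrative only: open one variable for one ensemble member"""
    return _gen_ds_var_single_ensemble(
        varname="TEMP",  # hypothetical variable name
        component="ocn",
        experiment="historical",  # hypothetical experiment name
        stream="pop.h",  # hypothetical stream name
        df_in=df_in,
        ensemble=0,
    )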
def compute_ann_mean(ds):
    """esmlab wrapper"""

    # return esmlab.climatology.compute_ann_mean(ds)

    # ignore certain warnings
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", module=".*nanops")
        ds_out = resample(ds, freq="ann")

    # esmlab appears to corrupt xarray indexes wrt time;
    # the following seems to reset them
    ds_out["time"] = ds_out["time"]

    # ensure time dim is first on time.bounds variable
    tb_name = ds_out.time.bounds
    if ds_out[tb_name].dims[0] != "time":
        ds_out[tb_name] = ds_out[tb_name].transpose()

    # reset time to mid-interval values
    ds_out = time_set_mid(ds_out, "time")

    # propagate particular encoding values
    for key in ["unlimited_dims"]:
        if key in ds.encoding:
            ds_out.encoding[key] = ds.encoding[key]

    # copy file attributes
    for key in ds.attrs:
        if key != "history":
            ds_out.attrs[key] = ds.attrs[key]

    # prepend to history file attribute if it already exists, otherwise set it
    key = "history"
    datestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
    msg = f"{datestamp}: created by esmlab.resample, with modifications from esmlab_wrap"
    if key in ds.attrs:
        ds_out.attrs[key] = "\n".join([msg, ds.attrs[key]])
    else:
        ds_out.attrs[key] = msg

    return ds_out
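# Minimal illustration of the transpose fix in compute_ann_mean above: if the
# bounds variable comes back with dims ("d2", "time"), a bare transpose()
# reverses the dims so that time is first. Synthetic data, illustrative only.
def _example_bounds_transpose():
    """illustrative only: ensure time is the leading dim of a bounds variable"""
    import numpy as np
    import xarray as xr

    tb = xr.DataArray(np.zeros((2, 3)), dims=("d2", "time"))
    if tb.dims[0] != "time":
        tb = tb.transpose()
    print(tb.dims)  # ('time', 'd2')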
def test_repl_coord(decode_times1, decode_times2, apply_chunk1):
    ds1 = time_set_mid(xr_ds_ex(decode_times1, nyrs=nyrs, var_const=var_const), "time")
    if apply_chunk1:
        ds1 = ds1.chunk({"time": 12})

    # change the time:bounds attribute and rename the corresponding variable
    tb_name_old = ds1["time"].attrs["bounds"]
    tb_name_new = tb_name_old + "_new"
    ds1["time"].attrs["bounds"] = tb_name_new
    ds1 = ds1.rename({tb_name_old: tb_name_new})

    # verify that repl_coord applied to xr_ds_ex gives the same result as
    # 1) executing time_set_mid and
    # 2) manually changing the bounds
    ds2 = repl_coord(
        "time", ds1, xr_ds_ex(decode_times2, nyrs=nyrs, var_const=var_const)
    )
    assert ds2.identical(ds1)

    assert ds2["time"].encoding == ds1["time"].encoding
    assert ds2["time"].chunks == ds1["time"].chunks
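# Hedged sketch of what test_repl_coord assumes repl_coord does: replace the
# named coordinate (and its bounds variable) in the target dataset with the
# one from the source dataset. This is an assumption inferred from the test,
# not the repo's actual implementation.
def _repl_coord_sketch(coordname, ds_src, ds_dst):
    """illustrative only: copy coordname and its bounds from ds_src into ds_dst"""
    tb_src = ds_src[coordname].attrs.get("bounds")
    tb_dst = ds_dst[coordname].attrs.get("bounds")
    # align the bounds variable name with the source before copying
    if tb_src is not None and tb_dst is not None and tb_src != tb_dst:
        ds_dst = ds_dst.rename({tb_dst: tb_src})
    ds_out = ds_dst.assign_coords({coordname: ds_src[coordname]})
    if tb_src is not None:
        ds_out[tb_src] = ds_src[tb_src]
    return ds_out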
def _tseries_gen(varname, component, ensemble, entries, cluster_in):
    """
    generate a tseries for a particular ensemble member, return a Dataset object
    """
    print_timestamp(f"varname={varname}")
    varname_resolved = _varname_resolved(varname, component)
    fnames = entries.loc[entries["ensemble"] == ensemble].files.tolist()
    print(fnames)

    with open(var_specs_fname, mode="r") as fptr:
        var_specs_all = yaml.safe_load(fptr)

    if varname in var_specs_all[component]["vars"]:
        var_spec = var_specs_all[component]["vars"][varname]
    else:
        var_spec = {}

    # use var-specific reduce_dims if it exists, otherwise use reduce_dims for component
    if "reduce_dims" in var_spec:
        reduce_dims = var_spec["reduce_dims"]
    else:
        reduce_dims = var_specs_all[component]["reduce_dims"]

    # get rank of varname from first file, used to set time chunksize
    # approximate number of time levels, assuming all files have same number
    # save time encoding from first file, to restore it in the multi-file case
    # https://github.com/pydata/xarray/issues/2921
    with xr.open_dataset(fnames[0]) as ds0:
        vardims = ds0[varname_resolved].dims
        rank = len(vardims)
        vertlen = ds0.dims[vardims[1]] if rank > 3 else 0
        time_chunksize = 10 * 12 if rank < 4 else 6
        ds0.chunk(chunks={time_name: time_chunksize})
        time_encoding = ds0[time_name].encoding
        var_encoding = ds0[varname_resolved].encoding
        ds0_attrs = ds0.attrs
        ds0_encoding = ds0.encoding
        drop_var_names_loc = drop_var_names(component, ds0, varname_resolved)

    # instantiate cluster, if not provided via argument
    # ignore dashboard warnings when instantiating
    if cluster_in is None:
        if "ncar_jobqueue" in sys.modules:
            with warnings.catch_warnings():
                warnings.filterwarnings(action="ignore", module=".*dashboard")
                cluster = ncar_jobqueue.NCARCluster()
        else:
            raise ValueError(
                "cluster_in not provided and ncar_jobqueue did not load successfully"
            )
    else:
        cluster = cluster_in

    # scale the number of workers with the vertical extent of varname
    workers = 12
    if vertlen >= 20:
        workers *= 2
    if vertlen >= 60:
        workers *= 2
    workers = 2 * round(workers / 2)  # round to nearest multiple of 2
    print_timestamp(f"calling cluster.scale({workers})")
    cluster.scale(workers)

    print_timestamp(f"dashboard_link={cluster.dashboard_link}")

    # create dask distributed client, connecting to workers
    with dask.distributed.Client(cluster) as client:
        print_timestamp("client instantiated")

        # tool to help track down file inconsistencies that trigger errors in open_mfdataset
        # test_open_mfdataset(fnames, time_chunksize, varname)

        # data_vars="minimal", to avoid introducing a time dimension to
        # time-invariant fields when there are multiple files
        # only chunk in time, because if you chunk over spatial dims,
        # then sum results depend on chunksize
        # https://github.com/pydata/xarray/issues/2902
        with xr.open_mfdataset(
            fnames,
            data_vars="minimal",
            coords="minimal",
            compat="override",
            combine="by_coords",
            chunks={time_name: time_chunksize},
            drop_variables=drop_var_names_loc,
        ) as ds_in:
            print_timestamp("open_mfdataset returned")

            # restore encoding for time from first file
            ds_in[time_name].encoding = time_encoding

            da_in_full = ds_in[varname_resolved]
            da_in_full.encoding = var_encoding

            var_units = clean_units(da_in_full.attrs["units"])
            if "unit_conv" in var_spec:
                var_units = f"({var_spec['unit_conv']})({var_units})"

            # construct averaging/integrating weight
            weight = get_weight(ds_in, component, reduce_dims)
            weight_attrs = weight.attrs
            weight = get_rmask(ds_in, component) * weight
            weight.attrs = weight_attrs
            print_timestamp("weight constructed")

            # compute regional sum of weights
            da_in_t0 = da_in_full.isel({time_name: 0}).drop(time_name)
            ones_masked_t0 = xr.ones_like(da_in_t0).where(da_in_t0.notnull())
            weight_sum = (ones_masked_t0 * weight).sum(dim=reduce_dims)
            weight_sum.name = f"weight_sum_{varname}"
            weight_sum.attrs = weight.attrs
            weight_sum.attrs[
                "long_name"
            ] = f"sum of weights used in tseries generation for {varname}"

            tlen = da_in_full.sizes[time_name]
            print_timestamp(f"tlen={tlen}")

            # use var-specific tseries_op if it exists, otherwise use tseries_op for component
            if "tseries_op" in var_spec:
                tseries_op = var_spec["tseries_op"]
            else:
                tseries_op = var_specs_all[component]["tseries_op"]

            ds_out_list = []

            time_step_nominal = min(2 * workers * time_chunksize, tlen)
            time_step = math.ceil(tlen / (tlen // time_step_nominal))
            print_timestamp(f"time_step={time_step}")

            for time_ind0 in range(0, tlen, time_step):
                print_timestamp(f"time_ind={time_ind0}, {time_ind0 + time_step}")
                da_in = da_in_full.isel(
                    {time_name: slice(time_ind0, time_ind0 + time_step)}
                )

                if tseries_op == "integrate":
                    da_out = (da_in * weight).sum(dim=reduce_dims)
                    da_out.name = varname
                    da_out.attrs["long_name"] = "Integrated " + da_in.attrs["long_name"]
                    da_out.attrs["units"] = cf_units.Unit(
                        f"({weight.attrs['units']})({var_units})"
                    ).format()
                elif tseries_op == "average":
                    da_out = (da_in * weight).sum(dim=reduce_dims)
                    ones_masked = xr.ones_like(da_in).where(da_in.notnull())
                    denom = (ones_masked * weight).sum(dim=reduce_dims)
                    da_out /= denom
                    da_out.name = varname
                    da_out.attrs["long_name"] = "Averaged " + da_in.attrs["long_name"]
                    da_out.attrs["units"] = cf_units.Unit(var_units).format()
                else:
                    msg = f"tseries_op={tseries_op} not implemented"
                    raise NotImplementedError(msg)

                print_timestamp("da_out computation setup")

                # propagate some settings from da_in to da_out
                da_out.encoding["dtype"] = da_in.encoding["dtype"]
                copy_fill_settings(da_in, da_out)

                ds_out = da_out.to_dataset()
                print_timestamp("ds_out generated")

                # copy particular variables from ds_in
                copy_var_list = [time_name]
                if "bounds" in ds_in[time_name].attrs:
                    copy_var_list.append(ds_in[time_name].attrs["bounds"])
                copy_var_list.extend(copy_var_names(component))
                ds_out = xr.merge(
                    [
                        ds_out,
                        ds_in[copy_var_list].isel(
                            {time_name: slice(time_ind0, time_ind0 + time_step)}
                        ),
                    ]
                )
                print_timestamp("copy_var_names added")

                # force computation of ds_out, while resources of client are still available
                print_timestamp("calling ds_out.load")
                ds_out_list.append(ds_out.load())
                print_timestamp("returned from ds_out.load")

            print_timestamp("concatenating ds_out_list datasets")
            ds_out = xr.concat(
                ds_out_list,
                dim=time_name,
                data_vars=[varname],
                coords="minimal",
                compat="override",
            )

            # set ds_out.time to mid-interval values
            ds_out = time_set_mid(ds_out, time_name)
            print_timestamp("time_set_mid returned")

            # copy file attributes, prepending a history message
            ds_out.attrs = ds0_attrs
            datestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
            msg = f"{datestamp}: created by {__file__}"
            if "history" in ds_out.attrs:
                ds_out.attrs["history"] = "\n".join([msg, ds_out.attrs["history"]])
            else:
                ds_out.attrs["history"] = msg
            ds_out.attrs["input_file_list"] = " ".join(fnames)

            for key in ["unlimited_dims"]:
                if key in ds0_encoding:
                    ds_out.encoding[key] = ds0_encoding[key]

            # restore encoding for time from first file
            ds_out[time_name].encoding = time_encoding

            # change output units, if specified in var_spec
            units_key = (
                "integral_display_units" if tseries_op == "integrate" else "display_units"
            )
            if units_key in var_spec:
                ds_out[varname] = conv_units(ds_out[varname], var_spec[units_key])
                print_timestamp("units converted")

            # add regional sum of weights
            ds_out[weight_sum.name] = weight_sum

    print_timestamp("ds_in and client closed")

    # if cluster was instantiated here, close it
    if cluster_in is None:
        cluster.close()

    return ds_out
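# Standalone illustration of the "average" branch in _tseries_gen above: a
# weighted mean over reduce_dims whose denominator only counts weights at
# non-NaN points, so masked cells do not bias the average. Synthetic data,
# illustrative only.
def _example_masked_weighted_average():
    """illustrative only: NaN-aware weighted average over one dim"""
    import numpy as np
    import xarray as xr

    da = xr.DataArray([[1.0, np.nan], [3.0, 4.0]], dims=("time", "x"))
    weight = xr.DataArray([2.0, 1.0], dims=("x",))
    numer = (da * weight).sum(dim="x")
    # ones where da is valid, NaN where it is masked
    ones_masked = xr.ones_like(da).where(da.notnull())
    denom = (ones_masked * weight).sum(dim="x")
    print((numer / denom).values)  # [1.         3.33333333]
    return numer / denom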