def _latlon_sel_gen(varname, isel_dict, component, ensemble, entries):
    """
    generate values for a particular ensemble member; return a Dataset object
    """
    print_timestamp(f"varname={varname}")
    varname_resolved = _varname_resolved(varname, component)
    fnames = entries.loc[entries["ensemble"] == ensemble].files.tolist()
    print(fnames)

    with open(var_specs_fname, mode="r") as fptr:
        var_specs_all = yaml.safe_load(fptr)

    if varname in var_specs_all[component]["vars"]:
        var_spec = var_specs_all[component]["vars"][varname]
    else:
        var_spec = {}

    ds_out_list = []

    with xr.open_dataset(fnames[0]) as ds0:
        drop_var_names_loc = drop_var_names(component, ds0, varname_resolved)
        var_list = [time_name, varname]
        if "bounds" in ds0[time_name].attrs:
            var_list.append(ds0[time_name].attrs["bounds"])
        var_list.extend(copy_var_names(component))
        for fname in fnames:
            with xr.open_dataset(fname,
                                 drop_variables=drop_var_names_loc) as ds_in:
                ds_out = ds_in[var_list].isel(isel_dict)
                ds_out_list.append(ds_out.load())

        ds_out = xr.concat(ds_out_list,
                           dim=time_name,
                           coords="minimal",
                           compat="override")

        # restore encoding for time from first file
        ds_out[time_name].encoding = ds0[time_name].encoding

        # set ds_out.time to mid-interval values
        ds_out = time_set_mid(ds_out, time_name)

        # copy file attributes
        ds_out.attrs = ds0.attrs

    datestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
    msg = f"{datestamp}: created by {__file__}"
    if "history" in ds_out.attrs:
        ds_out.attrs["history"] = "\n".join([msg, ds_out.attrs["history"]])
    else:
        ds_out.attrs["history"] = msg

    ds_out.attrs["input_file_list"] = " ".join(fnames)

    return ds_out
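# Hedged usage sketch for _latlon_sel_gen (not part of the original module):
# "entries" only needs "ensemble" and "files" columns, and isel_dict selects a
# single grid column by index. The catalog contents below are hypothetical.
#
# import pandas as pd
# entries = pd.DataFrame({
#     "ensemble": [0, 0],
#     "files": ["case.pop.h.0001-01.nc", "case.pop.h.0002-01.nc"],
# })
# ds_sel = _latlon_sel_gen("SST", {"nlat": 100, "nlon": 200}, "ocn", 0, entries)
# ds_sel.to_netcdf("SST.latlon_sel.nc")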
def test_time_set_mid(decode_times, deep, apply_chunk):
    ds = xr_ds_ex(decode_times, nyrs=nyrs, var_const=var_const, time_mid=False)
    if apply_chunk:
        ds = ds.chunk({"time": 12})

    mid_month_values = gen_time_bounds_values(nyrs).mean(axis=1)
    if decode_times:
        time_encoding = ds["time"].encoding
        expected_values = cftime.num2date(mid_month_values,
                                          time_encoding["units"],
                                          time_encoding["calendar"])
    else:
        expected_values = mid_month_values

    ds_out = time_set_mid(ds, "time", deep)

    assert ds_out.attrs == ds.attrs
    assert ds_out.encoding == ds.encoding
    assert ds_out.chunks == ds.chunks

    for varname in ds.variables:
        assert ds_out[varname].attrs == ds[varname].attrs
        assert ds_out[varname].encoding == ds[varname].encoding
        assert ds_out[varname].chunks == ds[varname].chunks
        if varname == "time":
            assert np.all(ds_out[varname].values == expected_values)
        else:
            assert np.all(ds_out[varname].values == ds[varname].values)
            assert (ds_out[varname].data is ds[varname].data) == (not deep)

    # verify that values are independent of ds being chunked in time
    ds_chunk = xr_ds_ex(decode_times,
                        nyrs=nyrs,
                        var_const=var_const,
                        time_mid=False).chunk({"time": 6})
    ds_chunk_out = time_set_mid(ds_chunk, "time")
    assert ds_chunk_out.identical(ds_out)
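# time_set_mid itself is defined elsewhere in this package; the test above pins
# down its contract: replace the time coordinate with the midpoint of its CF
# "bounds" variable while preserving attrs, encoding, and chunking. A minimal
# sketch of that idea (assumed, not the package's actual code), covering only
# numeric (not-yet-decoded) time values:
def time_set_mid_sketch(ds, time_name, deep=False):
    """return a copy of ds with ds[time_name] set to mid-interval values"""
    tb = ds[ds[time_name].attrs["bounds"]]
    bounds_dim = [d for d in tb.dims if d != time_name][0]
    attrs, encoding = ds[time_name].attrs, ds[time_name].encoding
    ds_out = ds.copy(deep=deep)
    ds_out = ds_out.assign_coords({time_name: tb.mean(dim=bounds_dim).values})
    ds_out[time_name].attrs = attrs
    ds_out[time_name].encoding = encoding
    return ds_out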
def compute_mon_anomaly(ds):
    """esmlab wrapper"""

    # esmlab.anomaly generates a fatal error for particular time values;
    # adjust them to avoid this
    tb_name = ds.time.bounds
    tb = ds[tb_name]
    time_decoded = tb.dtype == np.dtype("O")
    if time_decoded:
        tb_vals_encoded = cftime.date2num(ds[tb_name].values,
                                          ds.time.encoding["units"],
                                          ds.time.encoding["calendar"])
    else:
        tb_vals_encoded = ds[tb_name].values
    val0 = -0.5 / 24.0
    tb_vals_encoded = np.where(
        abs(tb_vals_encoded - val0) < 1.0e-10, val0, tb_vals_encoded)
    if time_decoded:
        ds[tb_name].values = cftime.num2date(tb_vals_encoded,
                                             ds.time.encoding["units"],
                                             ds.time.encoding["calendar"])
    else:
        ds[tb_name].values = tb_vals_encoded
    ds = time_set_mid(ds, "time")

    #     return esmlab.climatology.compute_mon_anomaly(ds)
    ds_out = anomaly(ds, clim_freq="mon")

    # propagate particular encoding values
    for key in ["unlimited_dims"]:
        if key in ds.encoding:
            ds_out.encoding[key] = ds.encoding[key]

    # copy file attributes, prepending history message
    for key in ds.attrs:
        if key == "history":
            datestamp = datetime.now(
                timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
            msg = f"{datestamp}: created by esmlab.anomaly, with modifications from esmlab_wrap"
            ds_out.attrs[key] = "\n".join([msg, ds.attrs[key]])
        else:
            ds_out.attrs[key] = ds.attrs[key]

    return ds_out
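# Hedged usage sketch (the file name below is hypothetical): compute_mon_anomaly
# expects monthly data whose time variable carries a CF "bounds" attribute plus
# units/calendar encoding.
#
# ds = xr.open_dataset("case.pop.h.SST.000101-010012.nc")
# ds_anom = compute_mon_anomaly(ds)
# ds_anom.to_netcdf("SST.mon_anomaly.nc")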
def _gen_ds_var_single_ensemble(varname, component, experiment, stream, df_in,
                                ensemble):
    """
    return xarray.Dataset containing varname for a single ensemble member
    """
    # get DataFrame of matching data_catalog entries
    df = df_in.loc[df_in["ensemble"] == ensemble]

    if df.empty:
        raise ValueError(
            f"no file matches found for varname={varname}, component={component}, experiment={experiment}, ensemble={ensemble}"
        )

    paths = df["files"].tolist()

    with xr.open_dataset(paths[0]) as ds0:
        rank = len(ds0[varname].dims)
        time_chunksize = 10 * 12 if rank < 4 else 12
        time_encoding = ds0[time_name].encoding
        ds_encoding = ds0.encoding
        drop_var_names_loc = drop_var_names(component, ds0, varname)

    ds_out = xr.open_mfdataset(
        paths,
        data_vars="minimal",
        coords="minimal",
        compat="override",
        combine="by_coords",
        drop_variables=drop_var_names_loc,
    ).chunk(chunks={time_name: time_chunksize})

    for key in ["units", "calendar"]:
        if key in time_encoding:
            ds_out[time_name].encoding[key] = time_encoding[key]

    # set ds_out.time to mid-interval values
    ds_out = time_set_mid(ds_out, time_name)

    for key in ["unlimited_dims"]:
        if key in ds_encoding:
            ds_out.encoding[key] = ds_encoding[key]

    return ds_out
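# Hedged usage sketch for _gen_ds_var_single_ensemble (catalog contents below
# are hypothetical): df_in is a data_catalog DataFrame with "ensemble" and
# "files" columns.
#
# import pandas as pd
# df_in = pd.DataFrame({
#     "ensemble": [1, 1],
#     "files": ["case.001.pop.h.0001.nc", "case.001.pop.h.0002.nc"],
# })
# ds = _gen_ds_var_single_ensemble("TEMP", "ocn", "esm-hist", "pop.h", df_in, 1)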
def compute_ann_mean(ds):
    """esmlab wrapper"""
    #     return esmlab.climatology.compute_ann_mean(ds)
    # ignore certain warnings
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", module=".*nanops")
        ds_out = resample(ds, freq="ann")

    # esmlab appears to corrupt xarray indexes wrt time
    # the following seems to reset them
    ds_out["time"] = ds_out["time"]

    # ensure time dim is first on time.bounds variable
    tb_name = ds_out.time.bounds
    if ds_out[tb_name].dims[0] != "time":
        ds_out[tb_name] = ds_out[tb_name].transpose()

    # reset time to midpoint
    ds_out = time_set_mid(ds_out, "time")

    # propagate particular encoding values
    for key in ["unlimited_dims"]:
        if key in ds.encoding:
            ds_out.encoding[key] = ds.encoding[key]

    # copy file attributes
    for key in ds.attrs:
        if key != "history":
            ds_out.attrs[key] = ds.attrs[key]

    # append to the "history" file attribute if it already exists, otherwise set it
    key = "history"
    datestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
    msg = (
        f"{datestamp}: created by esmlab.resample, with modifications from esmlab_wrap"
    )
    if key in ds.attrs:
        ds_out.attrs[key] = "\n".join([msg, ds.attrs[key]])
    else:
        ds_out.attrs[key] = msg

    return ds_out
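# For orientation only: an unweighted plain-xarray counterpart of the annual
# mean (not esmlab's implementation, which accounts for month lengths and time
# bounds), assuming decoded time values:
# ds_ann_rough = ds.groupby("time.year").mean("time")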
def test_repl_coord(decode_times1, decode_times2, apply_chunk1):
    ds1 = time_set_mid(xr_ds_ex(decode_times1, nyrs=nyrs, var_const=var_const),
                       "time")
    if apply_chunk1:
        ds1 = ds1.chunk({"time": 12})

    # change the time:bounds attribute and rename the corresponding variable
    tb_name_old = ds1["time"].attrs["bounds"]
    tb_name_new = tb_name_old + "_new"
    ds1["time"].attrs["bounds"] = tb_name_new
    ds1 = ds1.rename({tb_name_old: tb_name_new})

    # verify that repl_coord on xr_ds_ex gives same results as
    # 1) executing time_set_mid
    # 2) manually changing bounds
    ds2 = repl_coord("time", ds1,
                     xr_ds_ex(decode_times2, nyrs=nyrs, var_const=var_const))
    assert ds2.identical(ds1)

    assert ds2["time"].encoding == ds1["time"].encoding
    assert ds2["time"].chunks == ds1["time"].chunks
def _tseries_gen(varname, component, ensemble, entries, cluster_in):
    """
    generate a tseries for a particular ensemble member, return a Dataset object
    """
    print_timestamp(f"varname={varname}")
    varname_resolved = _varname_resolved(varname, component)
    fnames = entries.loc[entries["ensemble"] == ensemble].files.tolist()
    print(fnames)

    with open(var_specs_fname, mode="r") as fptr:
        var_specs_all = yaml.safe_load(fptr)

    if varname in var_specs_all[component]["vars"]:
        var_spec = var_specs_all[component]["vars"][varname]
    else:
        var_spec = {}

    # use var specific reduce_dims if it exists, otherwise use reduce_dims for component
    if "reduce_dims" in var_spec:
        reduce_dims = var_spec["reduce_dims"]
    else:
        reduce_dims = var_specs_all[component]["reduce_dims"]

    # get rank of varname from first file, used to set time chunksize
    # approximate number of time levels, assuming all files have same number
    # save time encoding from first file, to restore it in the multi-file case
    #     https://github.com/pydata/xarray/issues/2921
    with xr.open_dataset(fnames[0]) as ds0:
        vardims = ds0[varname_resolved].dims
        rank = len(vardims)
        vertlen = ds0.dims[vardims[1]] if rank > 3 else 0
        time_chunksize = 10 * 12 if rank < 4 else 6
        ds0.chunk(chunks={time_name: time_chunksize})
        time_encoding = ds0[time_name].encoding
        var_encoding = ds0[varname_resolved].encoding
        ds0_attrs = ds0.attrs
        ds0_encoding = ds0.encoding
        drop_var_names_loc = drop_var_names(component, ds0, varname_resolved)

    # instantiate cluster, if not provided via argument
    # ignore dashboard warnings when instantiating
    if cluster_in is None:
        if "ncar_jobqueue" in sys.modules:
            with warnings.catch_warnings():
                warnings.filterwarnings(action="ignore", module=".*dashboard")
                cluster = ncar_jobqueue.NCARCluster()
        else:
            raise ValueError(
                "cluster_in not provided and ncar_jobqueue did not load successfully"
            )
    else:
        cluster = cluster_in

    workers = 12
    if vertlen >= 20:
        workers *= 2
    if vertlen >= 60:
        workers *= 2
    workers = 2 * round(workers / 2)  # round to nearest multiple of 2
    print_timestamp(f"calling cluster.scale({workers})")
    cluster.scale(workers)

    print_timestamp(f"dashboard_link={cluster.dashboard_link}")

    # create dask distributed client, connecting to workers
    with dask.distributed.Client(cluster) as client:
        print_timestamp("client instantiated")

        # tool to help track down file inconsistencies that trigger errors in open_mfdataset
        # test_open_mfdataset(fnames, time_chunksize, varname)

        # data_vars = "minimal", to avoid introducing time dimension to time-invariant fields when there are multiple files
        # only chunk in time, because if you chunk over spatial dims, then sum results depend on chunksize
        #     https://github.com/pydata/xarray/issues/2902
        with xr.open_mfdataset(
            fnames,
            data_vars="minimal",
            coords="minimal",
            compat="override",
            combine="by_coords",
            chunks={time_name: time_chunksize},
            drop_variables=drop_var_names_loc,
        ) as ds_in:
            print_timestamp("open_mfdataset returned")

            # restore encoding for time from first file
            ds_in[time_name].encoding = time_encoding

            da_in_full = ds_in[varname_resolved]
            da_in_full.encoding = var_encoding

            var_units = clean_units(da_in_full.attrs["units"])
            if "unit_conv" in var_spec:
                var_units = f"({var_spec['unit_conv']})({var_units})"

            # construct averaging/integrating weight
            weight = get_weight(ds_in, component, reduce_dims)
            weight_attrs = weight.attrs
            weight = get_rmask(ds_in, component) * weight
            weight.attrs = weight_attrs
            print_timestamp("weight constructed")

            # compute regional sum of weights
            da_in_t0 = da_in_full.isel({time_name: 0}).drop(time_name)
            ones_masked_t0 = xr.ones_like(da_in_t0).where(da_in_t0.notnull())
            weight_sum = (ones_masked_t0 * weight).sum(dim=reduce_dims)
            weight_sum.name = f"weight_sum_{varname}"
            weight_sum.attrs = weight.attrs
            weight_sum.attrs[
                "long_name"
            ] = f"sum of weights used in tseries generation for {varname}"

            tlen = da_in_full.sizes[time_name]
            print_timestamp(f"tlen={tlen}")

            # use var specific tseries_op if it exists, otherwise use tseries_op for component
            if "tseries_op" in var_spec:
                tseries_op = var_spec["tseries_op"]
            else:
                tseries_op = var_specs_all[component]["tseries_op"]

            ds_out_list = []

            time_step_nominal = min(2 * workers * time_chunksize, tlen)
            time_step = math.ceil(tlen / (tlen // time_step_nominal))
            print_timestamp(f"time_step={time_step}")
            for time_ind0 in range(0, tlen, time_step):
                print_timestamp(f"time_ind={time_ind0}, {time_ind0 + time_step}")
                da_in = da_in_full.isel(
                    {time_name: slice(time_ind0, time_ind0 + time_step)}
                )

                if tseries_op == "integrate":
                    da_out = (da_in * weight).sum(dim=reduce_dims)
                    da_out.name = varname
                    da_out.attrs["long_name"] = "Integrated " + da_in.attrs["long_name"]
                    da_out.attrs["units"] = cf_units.Unit(
                        f"({weight.attrs['units']})({var_units})"
                    ).format()
                elif tseries_op == "average":
                    da_out = (da_in * weight).sum(dim=reduce_dims)
                    ones_masked = xr.ones_like(da_in).where(da_in.notnull())
                    denom = (ones_masked * weight).sum(dim=reduce_dims)
                    da_out /= denom
                    da_out.name = varname
                    da_out.attrs["long_name"] = "Averaged " + da_in.attrs["long_name"]
                    da_out.attrs["units"] = cf_units.Unit(var_units).format()
                else:
                    msg = f"tseries_op={tseries_op} not implemented"
                    raise NotImplementedError(msg)

                print_timestamp("da_out computation setup")

                # propagate some settings from da_in to da_out
                da_out.encoding["dtype"] = da_in.encoding["dtype"]
                copy_fill_settings(da_in, da_out)

                ds_out = da_out.to_dataset()

                print_timestamp("ds_out generated")

                # copy particular variables from ds_in
                copy_var_list = [time_name]
                if "bounds" in ds_in[time_name].attrs:
                    copy_var_list.append(ds_in[time_name].attrs["bounds"])
                copy_var_list.extend(copy_var_names(component))
                ds_out = xr.merge(
                    [
                        ds_out,
                        ds_in[copy_var_list].isel(
                            {time_name: slice(time_ind0, time_ind0 + time_step)}
                        ),
                    ]
                )

                print_timestamp("copy_var_names added")

                # force computation of ds_out, while resources of client are still available
                print_timestamp("calling ds_out.load")
                ds_out_list.append(ds_out.load())
                print_timestamp("returned from ds_out.load")

            print_timestamp("concatenating ds_out_list datasets")
            ds_out = xr.concat(
                ds_out_list,
                dim=time_name,
                data_vars=[varname],
                coords="minimal",
                compat="override",
            )

            # set ds_out.time to mid-interval values
            ds_out = time_set_mid(ds_out, time_name)

            print_timestamp("time_set_mid returned")

            # copy file attributes
            ds_out.attrs = ds0_attrs

            datestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
            msg = f"{datestamp}: created by {__file__}"
            if "history" in ds_out.attrs:
                ds_out.attrs["history"] = "\n".join([msg, ds_out.attrs["history"]])
            else:
                ds_out.attrs["history"] = msg

            ds_out.attrs["input_file_list"] = " ".join(fnames)

            for key in ["unlimited_dims"]:
                if key in ds0_encoding:
                    ds_out.encoding[key] = ds0_encoding[key]

            # restore encoding for time from first file
            ds_out[time_name].encoding = time_encoding

            # change output units, if specified in var_spec
            units_key = (
                "integral_display_units"
                if tseries_op == "integrate"
                else "display_units"
            )
            if units_key in var_spec:
                ds_out[varname] = conv_units(ds_out[varname], var_spec[units_key])
                print_timestamp("units converted")

            # add regional sum of weights
            ds_out[weight_sum.name] = weight_sum

    print_timestamp("ds_in and client closed")

    # if cluster was instantiated here, close it
    if cluster_in is None:
        cluster.close()

    return ds_out
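# The core reduction inside _tseries_gen, isolated as a runnable sketch. The
# weight below is a made-up uniform area array; in the real code it comes from
# package-specific helpers (get_weight, get_rmask).
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.random.rand(24, 3, 4),
    dims=("time", "lat", "lon"),
    attrs={"long_name": "example field", "units": "kg/m^2/s"},
)
weight = xr.DataArray(np.full((3, 4), 2.0e9), dims=("lat", "lon"),
                      attrs={"units": "m^2"})
reduce_dims = ["lat", "lon"]

# tseries_op == "integrate": weighted sum over the reduced dims
integral = (da * weight).sum(dim=reduce_dims)

# tseries_op == "average": weighted sum divided by the sum of weights over
# non-missing points
ones_masked = xr.ones_like(da).where(da.notnull())
average = (da * weight).sum(dim=reduce_dims) / (ones_masked * weight).sum(dim=reduce_dims)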