예제 #1
0
def insert_prediction(ds: xr.Dataset, ds_pred: xr.Dataset) -> xr.Dataset:
    """Stack target and predicted values along the derivation dimension.

    Variables of ``ds`` that are also data variables of ``ds_pred`` are
    labelled ``TARGET_COORD`` along ``DERIVATION_DIM_NAME``; every variable of
    ``ds_pred`` is labelled ``PREDICT_COORD``. Variables of ``ds`` that are not
    predicted are carried through unchanged.

    Args:
        ds: dataset holding the inputs and (possibly) the true values of the
            predicted variables.
        ds_pred: dataset of model predictions.

    Returns:
        The merge of the unpredicted variables, the targets and the
        predictions.
    """
    predicted = list(ds_pred.data_vars)
    passthrough = [name for name in ds.data_vars if name not in predicted]
    targets_present = [name for name in predicted if name in ds.data_vars]

    target = safe.get_variables(ds, targets_present)
    target = target.expand_dims(DERIVATION_DIM_NAME)
    target = target.assign_coords({DERIVATION_DIM_NAME: [TARGET_COORD]})

    prediction = ds_pred.expand_dims(DERIVATION_DIM_NAME)
    prediction = prediction.assign_coords(
        {DERIVATION_DIM_NAME: [PREDICT_COORD]}
    )

    return xr.merge([safe.get_variables(ds, passthrough), target, prediction])
예제 #2
0
 def _merge_with_overlap(self, datasets: Sequence[xr.Dataset]) -> xr.Dataset:
     """Merge ``datasets``, stacking shared variables along the overlap dim.

     Variables named in ``self._var_overlap`` are combined along
     ``self._overlap_dim``. A dataset that lacks that dimension gets it added,
     labelled with its matching entry of ``self._source_names``. All remaining
     variables are merged directly.
     """
     non_overlapping = xr.merge(
         [ds.drop_vars(list(self._var_overlap)) for ds in datasets]
     )

     def _tag_overlap(ds, label):
         # subset to the shared variables and label them with their source
         shared = safe.get_variables(ds, self._var_overlap)
         if self._overlap_dim in ds.dims:
             return shared
         return shared.expand_dims(self._overlap_dim).assign_coords(
             {self._overlap_dim: [label]}
         )

     overlapping = [
         _tag_overlap(ds, label)
         for ds, label in zip(datasets, self._source_names)
     ]
     return xr.merge(overlapping + [non_overlapping])
예제 #3
0
def _consolidate_dimensioned_data(ds):
    """Split batched diagnostics into a diagnostics dataset and scalar metrics.

    A variable counts as a scalar metric when it holds exactly one value per
    batch, i.e. its total size equals the length of the ``batch`` dimension.

    Args:
        ds: dataset of per-batch diagnostics and metrics with a ``batch``
            dimension.

    Returns:
        Tuple ``(ds_diagnostics, ds_scalar_metrics)``.
        ``ds_diagnostics`` contains every variable of ``ds``: the dimensioned
        quantities, so they are saved as netcdf, and also the scalar metrics.
        ``ds_scalar_metrics`` contains only the scalar metrics.
    """
    n_batches = len(ds.batch)
    scalar_metrics = [var for var in ds if ds[var].size == n_batches]
    ds_scalar_metrics = safe.get_variables(ds, scalar_metrics)
    # Merging ``ds`` with its own non-scalar subset would be a no-op, so the
    # diagnostics dataset is simply a (shallow) copy holding all variables.
    ds_diagnostics = ds.copy()
    return ds_diagnostics, ds_scalar_metrics
예제 #4
0
def _compute_diagnostics(
    batches: Sequence[xr.Dataset],
    grid: xr.Dataset,
    predicted_vars: List[str],
    n_jobs: int,
) -> Tuple[xr.Dataset, xr.Dataset]:
    """Compute prediction-vs-target diagnostics per batch and combine them.

    Args:
        batches: datasets holding target and predicted values stacked along
            ``DERIVATION_DIM_NAME`` (selected with ``TARGET_COORD`` and
            ``PREDICT_COORD``), plus ``DELP`` and ``time``.
        grid: grid dataset passed to ``compute_diagnostics``.
        predicted_vars: names of the model's output variables.
        n_jobs: number of jobs passed to ``compute_diagnostics``.

    Returns:
        Tuple ``(ds_diagnostics, ds_scalar_metrics)``.
        ``ds_diagnostics`` holds the diagnostics averaged over batches.
        ``ds_scalar_metrics`` holds the scalar metrics, still per batch.
        Both have r2 inserted and standardized names.

    Raises:
        ValueError: if ``batches`` yields no datasets.
    """
    batches_summary = []

    for i, ds in enumerate(batches):
        logger.info(f"Processing batch {i+1}/{len(batches)}")

        # insert column-integrated versions of the 3D predicted variables
        diagnostic_vars_3d = [var for var in predicted_vars if is_3d(ds[var])]
        ds = ds.pipe(insert_column_integrated_vars, diagnostic_vars_3d).load()

        full_predicted_vars = [
            var for var in ds if DERIVATION_DIM_NAME in ds[var].dims
        ]
        # water vapor path accompanies Q2 diagnostics; add it only once
        needs_water_vapor_path = (
            "dQ2" in full_predicted_vars or "Q2" in full_predicted_vars
        )
        if needs_water_vapor_path and "water_vapor_path" not in full_predicted_vars:
            full_predicted_vars.append("water_vapor_path")

        prediction = safe.get_variables(
            ds.sel({DERIVATION_DIM_NAME: PREDICT_COORD}), full_predicted_vars
        )
        target = safe.get_variables(
            ds.sel({DERIVATION_DIM_NAME: TARGET_COORD}), full_predicted_vars
        )
        ds_summary = compute_diagnostics(
            prediction, target, grid, ds[DELP], n_jobs=n_jobs
        )
        # this is kept as a coord to use in plotting a histogram of test timesteps
        ds_summary["time"] = ds["time"]

        batches_summary.append(ds_summary.load())
        # drop the reference to the loaded batch so it can be freed before the
        # loop fetches the next one
        del ds

    if not batches_summary:
        raise ValueError("No batches were provided to compute diagnostics on.")

    # combine the batches; the average over batches is taken on return
    ds_summary = xr.concat(batches_summary, dim="batch")
    ds_diagnostics, ds_scalar_metrics = _consolidate_dimensioned_data(
        ds_summary
    )

    ds_scalar_metrics = insert_r2(ds_scalar_metrics)
    ds_diagnostics = ds_diagnostics.pipe(insert_r2).pipe(insert_rmse)
    ds_diagnostics, ds_scalar_metrics = _standardize_names(
        ds_diagnostics, ds_scalar_metrics
    )
    return ds_diagnostics.mean("batch"), ds_scalar_metrics
예제 #5
0
    def predict(self, X: xr.Dataset) -> xr.Dataset:
        """Predict outputs for ``X`` on data stacked over non-vertical dims.

        The input variables are stacked into ``SAMPLE_DIM_NAME``, passed to
        ``self._predict_on_stacked_data``, then unstacked and matched back to
        the coordinates of ``X``.
        """
        inputs = safe.get_variables(X, self.input_variables)
        stacked_inputs = stack_non_vertical(inputs)
        stacked_prediction = self._predict_on_stacked_data(stacked_inputs)

        # restore the sample index so the prediction can be unstacked
        sample_coords = {SAMPLE_DIM_NAME: stacked_inputs[SAMPLE_DIM_NAME]}
        prediction = stacked_prediction.assign_coords(sample_coords).unstack(
            SAMPLE_DIM_NAME
        )
        return match_prediction_to_input_coords(X, prediction)
예제 #6
0
 def transform(ds):
     """Add grid data to ``ds``, derive ``variables`` and insert predictions.

     Uses ``grid``, ``variables`` and ``predictor``, which are not defined in
     this function (free variables from the surrounding scope).
     """
     # compat="override" keeps the first object's value on conflicts, so the
     # dataset's land_sea_mask takes priority over the grid's.
     combined = xr.merge([ds, grid], compat="override")  # type: ignore
     derived = DerivedMapping(combined).dataset(variables)
     features = safe.get_variables(derived, variables)
     prediction = predictor.predict_columnwise(features, feature_dim="z")
     return insert_prediction(derived, prediction)
예제 #7
0
def _get_prescribed_ds(
    dataset_key: str,
    variables: Sequence[str],
    timesteps: Optional[Sequence[cftime.DatetimeJulian]],
    consolidated: bool = True,
) -> Tuple[xr.Dataset, xr.DataArray]:
    """Open ``dataset_key`` and prepare the requested variables for prescribing.

    Args:
        dataset_key: key passed to ``_open_ds``.
        variables: variables to keep.
        timesteps: if given, the data is passed through
            ``_time_interpolate_data`` with these times.
        consolidated: passed through to ``_open_ds``.

    Returns:
        Tuple of the loaded dataset with its ``time`` coordinate dropped, and
        that ``time`` coordinate.
    """
    logger.info(f"Setting up dataset for state setting: {dataset_key}")
    subset = get_variables(_open_ds(dataset_key, consolidated), variables)
    if timesteps is not None:
        subset = _time_interpolate_data(subset, timesteps, variables)
    times = subset.coords["time"]
    loaded = subset.drop_vars(names="time").load()
    return loaded, times
예제 #8
0
def _get_batch(
    mapper: Mapping[str, xr.Dataset],
    data_vars: Sequence[str],
    keys: Iterable[str],
) -> xr.Dataset:
    """
    Selects requested variables in the dataset that are there by default
    (i.e., not added in derived step), converts time strings to time, and combines
    into a single dataset.

    Args:
        mapper: mapping from time-string keys to datasets.
        data_vars: requested variable names.
        keys: keys of ``mapper`` to combine; any iterable, including a one-shot
            iterator.

    Returns:
        Datasets for ``keys`` concatenated along a ``TIME_NAME`` index, limited
        to the non-derived requested variables.
    """
    # ``keys`` is consumed twice below; materialize it so that a generator
    # yields the same keys for the time index and for the datasets.
    key_list = list(keys)
    time_coords = [parse_datetime_from_str(key) for key in key_list]
    ds = xr.concat(
        [mapper[key] for key in key_list],
        pd.Index(time_coords, name=TIME_NAME),
    )
    nonderived_vars = nonderived_variables(data_vars, tuple(ds.data_vars))
    return safe.get_variables(ds, nonderived_vars)
예제 #9
0
def conditional_average_over_domain(
    ds: xr.Dataset,
    grid: xr.Dataset,
    primary_vars: Sequence[str],
    domain: str,
    net_precipitation: Optional[xr.DataArray] = None,
    uninformative_coords: Sequence[str] = ("tile", "z", "y", "x"),
) -> xr.Dataset:
    """Average variables over a single domain, then over time.

    Args:
        ds: xarray dataset with the relevant variables batched in time
        grid: xarray dataset containing grid variables
            (latb, lonb, lat, lon, area, land_sea_mask)
        primary_vars: variables to average over the domain
        domain: name of the domain to average over. If it contains
            "net_precipitation", cells are classified with
            ``_snap_net_precipitation_to_type(net_precipitation)``. Otherwise
            they are classified by ``grid[SURFACE_TYPE]``.
        net_precipitation: xr.DataArray of net_precipitation values used to
            classify cells; required when ``domain`` contains
            "net_precipitation"
        uninformative_coords: sequence of names of uninformative (i.e.,
            range(len(dim))) coordinates to be dropped

    Returns:
        diagnostic_ds: xarray dataset of the domain-averaged ``primary_vars``,
        averaged over ``time``

    Raises:
        ValueError: if ``domain`` contains "net_precipitation" but
            ``net_precipitation`` is None.
    """
    drop_names = list(uninformative_coords)
    ds = ds.drop_vars(names=drop_names, errors="ignore")
    grid = grid.drop_vars(names=drop_names, errors="ignore")

    if "net_precipitation" in domain:
        if net_precipitation is None:
            raise ValueError(
                f"net_precipitation must be provided for domain {domain!r}"
            )
        cell_type = _snap_net_precipitation_to_type(net_precipitation).drop_vars(
            names=drop_names, errors="ignore"
        )
    else:
        cell_type = _snap_mask_to_type(grid[SURFACE_TYPE])

    domain_average = _conditional_average(
        safe.get_variables(ds, primary_vars), cell_type, domain, grid["area"],
    )
    return domain_average.mean("time")
예제 #10
0
def _get_transect(ds_snapshot: xr.Dataset, grid: xr.Dataset,
                  variables: Sequence[str]):
    """Interpolate snapshot fields to pressure levels and take a transect.

    For each variable, the "target" and "predict" entries of the
    ``derivation`` dimension are interpolated to pressure levels (using
    ``ds_snapshot[DELP]`` over ``z``) and re-joined along ``derivation``.
    The result is merged with ``grid``, and a meridional transect of the
    variables (plus lat/lon) is returned.
    """
    regridded = xr.Dataset()
    for name in variables:
        per_derivation = []
        for derivation in ["target", "predict"]:
            per_derivation.append(
                interpolate_to_pressure_levels(
                    field=ds_snapshot[name].sel(derivation=derivation),
                    delp=ds_snapshot[DELP],
                    dim="z",
                )
            )
        regridded[name] = xr.concat(per_derivation, dim="derivation")

    with_grid = xr.merge([regridded, grid])
    transect_inputs = safe.get_variables(
        with_grid, list(variables) + ["lat", "lon"]
    )
    return meridional_transect(transect_inputs)
예제 #11
0
def data_dirs(tmpdir_factory):
    """Write small single-tile zarr stores built from schemas and return paths.

    Returns:
        Tuple of paths ``(diag, restart, gfsphysics, area)`` to the written
        zarr stores.
    """
    tmpdir = tmpdir_factory.mktemp("input_data")

    grid_variables = [
        "grid_lat_coarse",
        "grid_latt_coarse",
        "grid_lon_coarse",
        "grid_lont_coarse",
    ]
    # use a small tile for much faster testing
    n = 48
    coarse_window = dict(grid_xt_coarse=slice(0, n),
                         grid_yt_coarse=slice(0, n))
    diag_selectors = dict(tile=[0], time=[0, 1], **coarse_window)
    restart_selectors = dict(tile=[0],
                             time=[0, 1, 2],
                             grid_xt=slice(0, n),
                             grid_yt=slice(0, n))

    diags = safe.get_variables(
        open_schema("diag.json"),
        budget.config.PHYSICS_VARIABLES + grid_variables,
    ).isel(diag_selectors)
    restart = open_schema("restart.json").isel(restart_selectors)
    gfsphysics = open_schema("gfsphysics.json").isel(diag_selectors)
    area = open_schema("area.json").isel(tile=[0], **coarse_window).load()

    stores = {
        "diag.zarr": diags,
        "restart.zarr": restart,
        "gfsphysics.zarr": gfsphysics,
        "area.zarr": area,
    }
    paths = []
    for filename, dataset in stores.items():
        path = str(tmpdir.join(filename))
        dataset.to_zarr(path, mode="w", consolidated=True)
        paths.append(path)
    return tuple(paths)
예제 #12
0
    def predict(self, X: xr.Dataset) -> xr.Dataset:
        """Predict an output xarray dataset from an input xarray dataset.

        Each output variable takes its stored value from ``self._outputs``
        (0.0 if absent). A numpy array value is repeated for every sample along
        an extra ``"z"`` dimension. Any other value is converted to float and
        filled for every sample.
        """
        stacked_X = stack_non_vertical(
            safe.get_variables(X, self.input_variables))
        n_samples = len(stacked_X[SAMPLE_DIM_NAME])

        def _constant_array(value) -> xr.DataArray:
            # build one constant output per sample
            if isinstance(value, np.ndarray):
                tiled = np.repeat(value[None, :], repeats=n_samples, axis=0)
                return xr.DataArray(data=tiled, dims=[SAMPLE_DIM_NAME, "z"])
            filled = np.full([n_samples], float(value))
            return xr.DataArray(data=filled, dims=[SAMPLE_DIM_NAME])

        data_vars = {
            name: _constant_array(self._outputs.get(name, 0.0))
            for name in self.output_variables
        }
        coords: Optional[Mapping[Hashable, Any]] = {
            SAMPLE_DIM_NAME: stacked_X.coords[SAMPLE_DIM_NAME]
        }
        stacked_pred = xr.Dataset(data_vars=data_vars, coords=coords)
        pred = stacked_pred.unstack(SAMPLE_DIM_NAME)
        return match_prediction_to_input_coords(X, pred)
예제 #13
0
def main(args):
    """Compute offline diagnostics and metrics for an ML model and save them.

    Steps:
    1. Load the data config, grid and model.
    2. Predict on each batch and compute diagnostics and scalar metrics.
    3. Try to plot model sensitivity figures.
    4. If a data mapper is available, add a prediction snapshot and a
       transect.
    5. Write the netcdf and json outputs to ``args.output_path``.

    Args:
        args: parsed command-line arguments. Reads ``data_yaml``, ``grid``,
            ``grid_resolution``, ``model_path``, ``output_path``, ``n_jobs``
            and ``snapshot_time``.
    """
    logger.info("Starting diagnostics routine.")

    # read the data config once: it is both parsed and copied to the output
    with fsspec.open(args.data_yaml, "r") as f:
        data_yaml_text = f.read()
    as_dict = yaml.safe_load(data_yaml_text)
    config = loaders.BatchesLoader.from_dict(as_dict)

    logger.info("Reading grid...")
    if not args.grid:
        # By default, read the appropriate resolution grid from vcm.catalog
        grid = load_grid_info(args.grid_resolution)
    else:
        with fsspec.open(args.grid, "rb") as f:
            grid = xr.open_dataset(f, engine="h5netcdf").load()

    logger.info("Opening ML model")
    model = fv3fit.load(args.model_path)

    # add Q2 and total water path for PW-Q2 scatterplots and net precip domain averages
    if any("Q2" in v for v in model.output_variables):
        model = fv3fit.DerivedModel(model, derived_output_variables=["Q2"])
        model_variables = _variables_to_load(model) + ["water_vapor_path"]
    else:
        model_variables = _variables_to_load(model)

    output_data_yaml = os.path.join(args.output_path, "data_config.yaml")
    with fsspec.open(output_data_yaml, "w") as f_out:
        f_out.write(data_yaml_text)
    batches = config.load_batches(model_variables)
    predict_function = _get_predict_function(model, model_variables, grid)
    batches = loaders.Map(predict_function, batches)

    # compute diags
    ds_diagnostics, ds_scalar_metrics = _compute_diagnostics(
        batches,
        grid,
        predicted_vars=model.output_variables,
        n_jobs=args.n_jobs,
    )
    ds_diagnostics = ds_diagnostics.update(grid)

    # save model sensitivity figures - these exclude derived variables
    base_model = (
        model.base_model if isinstance(model, fv3fit.DerivedModel) else model
    )
    sensitivity_dir = os.path.join(args.output_path, "model_sensitivity_figures")
    try:
        plot_jacobian(base_model, sensitivity_dir)  # type: ignore
    except AttributeError:
        # fall back to feature-importance plots when no Jacobian can be made
        try:
            input_feature_indices = get_variable_indices(
                data=batches[0], variables=base_model.input_variables)
            plot_rf_feature_importance(
                input_feature_indices,
                base_model,
                sensitivity_dir,
            )
        except AttributeError:
            logger.info(
                f"Base model is {type(base_model).__name__}, "
                "which currently has no feature importance or Jacobian "
                "calculation implemented.")

    mapper = _get_data_mapper_if_exists(config)
    if mapper is not None:
        snapshot_time = (args.snapshot_time or sorted(
            config.kwargs.get("timesteps", list(mapper.keys())))[0])
        snapshot_key = nearest_time(snapshot_time, list(mapper.keys()))
        ds_snapshot = predict_function(mapper[snapshot_key])

        vertical_vars = [
            var for var in model.output_variables if is_3d(ds_snapshot[var])
        ]
        ds_snapshot = insert_column_integrated_vars(ds_snapshot, vertical_vars)
        predicted_vars = [
            var for var in ds_snapshot if "derivation" in ds_snapshot[var].dims
        ]

        # add snapshotted prediction to saved diags.nc
        ds_diagnostics = ds_diagnostics.merge(
            safe.get_variables(ds_snapshot, predicted_vars).rename(
                {v: f"{v}_snapshot" for v in predicted_vars}))

        ds_transect = _get_transect(ds_snapshot, grid, vertical_vars)
        _write_nc(ds_transect, args.output_path, TRANSECT_NC_NAME)

    ds_diagnostics = _add_derived_diagnostics(ds_diagnostics)

    _write_nc(
        ds_diagnostics,
        args.output_path,
        DIAGS_NC_NAME,
    )

    # convert and output metrics json
    metrics = _average_metrics_dict(ds_scalar_metrics)
    with fsspec.open(os.path.join(args.output_path, METRICS_JSON_NAME),
                     "w") as f:
        json.dump(metrics, f, indent=4)

    metadata = {
        "model_path": args.model_path,
        "data_config": dataclasses.asdict(config),
    }
    with fsspec.open(os.path.join(args.output_path, METADATA_JSON_NAME),
                     "w") as f:
        json.dump(metadata, f, indent=4)

    logger.info("Finished processing dataset diagnostics and metrics.")
예제 #14
0
def _load_wind_rotation_matrix(res: str) -> xr.Dataset:
    """Return the wind rotation coefficients for grid resolution ``res``.

    Opens catalog entry ``wind_rotation/{res}`` with ``to_dask()`` and keeps
    only ``WIND_ROTATION_COEFFICIENTS``.
    """
    entry = catalog[f"wind_rotation/{res}"]
    return safe.get_variables(entry.to_dask(), WIND_ROTATION_COEFFICIENTS)
예제 #15
0
def _load_grid(res: str) -> xr.Dataset:
    """Load lat, lon and land_sea_mask for grid resolution ``res``.

    Combines catalog entries ``grid/{res}`` and ``landseamask/{res}`` (opened
    with ``to_dask()``), keeps only lat/lon/land_sea_mask, and drops the
    ``tile`` variable.
    """
    grid = catalog[f"grid/{res}"].to_dask()
    land_sea_mask = catalog[f"landseamask/{res}"].to_dask()
    grid = grid.assign({"land_sea_mask": land_sea_mask["land_sea_mask"]})
    # drop the tiles so that this is compatible with other indexing conventions
    # (drop_vars replaces the deprecated Dataset.drop for variable names)
    return safe.get_variables(grid, ["lat", "lon", "land_sea_mask"]).drop_vars(
        "tile"
    )
예제 #16
0
def load_grid_info(res: str = "c48"):
    """Load grid, wind-rotation and land-sea-mask data for resolution ``res``.

    Reads catalog entries ``grid/{res}``, ``wind_rotation/{res}`` and
    ``landseamask/{res}``, merges them, keeps ``GRID_INFO_VARS`` and drops
    the ``tile`` variable.
    """
    grid = catalog[f"grid/{res}"].read()
    wind_rotation = catalog[f"wind_rotation/{res}"].read()
    land_sea_mask = catalog[f"landseamask/{res}"].read()
    grid_info = xr.merge([grid, wind_rotation, land_sea_mask])
    # drop_vars replaces the deprecated Dataset.drop for variable names
    return safe.get_variables(grid_info, GRID_INFO_VARS).drop_vars("tile")