def sink(self, time: cftime.DatetimeJulian, data: Mapping[str, xr.DataArray]):
    """Write one TensorBoard image summary per variable, then advance the step.

    Each array in ``data`` is drawn on its own fresh pyplot figure with
    ``DataArray.plot()``. The figure is converted with ``plot_to_image`` and
    written via ``tf.summary.image``, named after the variable and tagged
    with the current ``self.step``. The figure is then closed.

    Args:
        time: accepted for interface compatibility; not used by this method.
        data: mapping from variable name to the array to plot.
    """
    for variable, array in data.items():
        figure = plt.figure()
        array.plot()
        summary_image = plot_to_image(figure)
        tf.summary.image(f"{variable}", summary_image, step=self.step)
        plt.close(figure)
    # One step per sink call, shared by every variable written above.
    self.step += 1
def plot_time_heights(
    prognostic: xr.Dataset, baseline: xr.Dataset, do_variables=COMPARE_VARS
):
    """Log one ``plot_global_avg_by_height_panel`` figure per variable to wandb.

    Both datasets are first passed through ``consistent_time_len``. Each
    variable's panel is logged under ``avg_time_height/{name}`` and its
    figure is closed afterwards.

    Args:
        prognostic: emulated/prognostic dataset.
        baseline: reference dataset to compare against.
        do_variables: names of variables to plot; each must exist in both
            datasets (a missing one raises from the dataset lookup).
    """
    prog_aligned, base_aligned = consistent_time_len(prognostic, baseline)
    for var_name in do_variables:
        panel_fig = plot_global_avg_by_height_panel(
            prog_aligned[var_name], base_aligned[var_name]
        )
        log_key = f"avg_time_height/{var_name}"
        wandb.log({log_key: wandb.Image(plot_to_image(panel_fig))})
        plt.close(panel_fig)
def plot_lat_heights(
    prognostic: xr.Dataset,
    baseline: xr.Dataset,
    do_variables=COMPARE_VARS,
    n_times_avg: int = 8,
):
    """Log zonal-average comparison panels of the end-of-run time mean to wandb.

    Both datasets are aligned with ``consistent_time_len``. Each is then
    averaged over its last ``n_times_avg`` time steps, or all steps if the
    run is shorter. For every variable, the two time means are reduced with
    ``zonal_average_approximate`` using the baseline's ``lat`` coordinate.
    They are plotted with ``plot_global_avg_by_height_panel`` (``x="lat"``)
    and logged under ``zonal_avg/{name}``.

    Args:
        prognostic: emulated/prognostic dataset.
        baseline: reference dataset to compare against.
        do_variables: names of variables to plot; each must exist in both
            datasets.
        n_times_avg: number of trailing time steps to average over.
            Defaults to 8, the previously hard-coded window.

    Raises:
        ValueError: if ``n_times_avg`` is less than 1, since an empty
            window would average over nothing.
    """
    if n_times_avg < 1:
        raise ValueError(f"n_times_avg must be >= 1, got {n_times_avg}")

    prognostic, baseline = consistent_time_len(prognostic, baseline)
    ntimes = len(prognostic.time)
    # Clamp at 0 so runs shorter than the window average every available step.
    selection = slice(max(ntimes - n_times_avg, 0), ntimes)
    prog_near_end = prognostic.isel(time=selection).mean(dim="time")
    base_near_end = baseline.isel(time=selection).mean(dim="time")
    lat = base_near_end["lat"]
    for name in do_variables:
        prog_zonal = zonal_average_approximate(lat, prog_near_end[name])
        base_zonal = zonal_average_approximate(lat, base_near_end[name])
        fig = plot_global_avg_by_height_panel(prog_zonal, base_zonal, x="lat")
        wandb.log({f"zonal_avg/{name}": wandb.Image(plot_to_image(fig))})
        plt.close(fig)
def plot_global_means(
    prognostic: xr.Dataset, baseline: xr.Dataset, do_variables=COMPARE_VARS
):
    """Overlay prognostic and baseline plots per variable and log them to wandb.

    Both datasets are aligned with ``consistent_time_len``. Variables missing
    from either dataset are skipped with an info-level log message. Each
    remaining variable is drawn with ``DataArray.plot`` on a single shared
    axes: the prognostic series is labelled "Emulation" and the baseline is
    labelled "Baseline" at alpha 0.6. The figure is logged under
    ``global_avg/{name}`` and then closed.

    NOTE(review): overlaying on one axes presumes 1-D (e.g. time-series)
    data per variable -- confirm against callers.

    Args:
        prognostic: emulated/prognostic dataset.
        baseline: reference dataset to compare against.
        do_variables: names of variables to plot.
    """
    prognostic, baseline = consistent_time_len(prognostic, baseline)
    for name in do_variables:
        if name not in baseline or name not in prognostic:
            logger.info("Skipping global mean due to missing variable: %s", name)
            continue
        fig, ax = plt.subplots()
        fig.set_dpi(80)
        prognostic[name].plot(ax=ax, label="Emulation")
        baseline[name].plot(ax=ax, label="Baseline", alpha=0.6)
        # Attach the legend to the axes drawn on above. plt.legend() would
        # target pyplot's implicit "current" axes instead.
        ax.legend()
        wandb.log({f"global_avg/{name}": wandb.Image(plot_to_image(fig))})
        plt.close(fig)
def plot_transects(
    prognostic: xr.Dataset, baseline: xr.Dataset, do_variables=TRANSECT_VARS
):
    """Log side-by-side meridional transects (emulation vs. baseline) to wandb.

    Two time indices are used: ``0`` ("start") and ``len(prognostic.time) - 2``
    ("near_end", the second-to-last step). For each index and each variable,
    ``plot_meridional`` is drawn for both datasets on a 1x2 figure (8x4 in,
    80 dpi). The figure is logged under
    ``meridional_transect/{time_name}/{name}`` and closed.

    NOTE(review): unlike the other comparison plots, this does not call
    ``consistent_time_len``. The time indices come from ``prognostic`` alone,
    so ``baseline`` is assumed to have at least as many time steps -- confirm.

    Args:
        prognostic: emulated/prognostic dataset.
        baseline: reference dataset to compare against.
        do_variables: names of variables to plot.
    """
    tidx_map = {"start": 0, "near_end": len(prognostic.time) - 2}
    for time_name, tidx in tidx_map.items():
        for name in do_variables:
            fig, ax = plt.subplots(1, 2, figsize=(8, 4), dpi=80)
            plot_meridional(prognostic.isel(time=tidx), name, ax=ax[0])
            plot_meridional(baseline.isel(time=tidx), name, ax=ax[1])
            ax[0].set_title("Emulation")
            ax[1].set_title("Baseline")
            # Lay out this figure explicitly. plt.tight_layout() acts on
            # pyplot's *current* figure, which a plotting helper may change.
            fig.tight_layout()
            log_name = f"meridional_transect/{time_name}/{name}"
            wandb.log({log_name: wandb.Image(plot_to_image(fig))})
            plt.close(fig)
def plot_spatial_comparisons(
    prognostic: xr.Dataset,
    baseline: xr.Dataset,
    do_variables=COMPARE_VARS,
    time_idxs: Mapping[str, int] = None,
    level_map: Mapping[str, int] = None,
):
    """Log 2-panel-with-difference spatial maps for chosen times and levels.

    For every variable, every ``level_map`` entry and every ``time_idxs``
    entry (in that nesting order), both datasets are sliced at the given
    ``time`` and ``z`` indices. The slices are drawn with
    ``plot_spatial_2panel_with_diff`` and the figure is logged under
    ``spatial_comparison/{time}/{level}/{name}``.

    Args:
        prognostic: emulated/prognostic dataset.
        baseline: reference dataset to compare against.
        do_variables: names of variables to plot.
        time_idxs: label -> time index. ``None`` (the default) means
            ``{"start": 0, "near_end": len(prognostic.time) - 2}``.
        level_map: label -> ``z`` index. ``None`` (the default) means
            ``{"lower": 75, "upper_BL": 60, "upper_atm": 20}``; this
            presumes at least 76 vertical levels -- confirm for the grid.
    """
    if time_idxs is None:
        time_idxs = {"start": 0, "near_end": len(prognostic.time) - 2}
    if level_map is None:
        level_map = {"lower": 75, "upper_BL": 60, "upper_atm": 20}

    for name in do_variables:
        for level_label, z_index in level_map.items():
            for time_label, t_index in time_idxs.items():
                prog_slice = prognostic.isel(time=t_index, z=z_index)
                base_slice = baseline.isel(time=t_index, z=z_index)
                figure = plot_spatial_2panel_with_diff(prog_slice, base_slice, name)
                log_name = f"spatial_comparison/{time_label}/{level_label}/{name}"
                wandb.log({log_name: wandb.Image(plot_to_image(figure))})
                plt.close(figure)
def log_profiles(self, key, data, step):
    """Line-plot ``data`` and log the image to wandb under ``key``.

    Args:
        key: wandb log key for the image.
        data: values passed straight to ``plt.plot``.
        step: wandb step to associate with the logged image.
    """
    figure = plt.figure()
    plt.plot(data)
    image = wandb.Image(plot_to_image(figure))
    wandb.log({key: image}, step=step)
    plt.close(figure)
def log_profiles(self, key, data, step):
    """Line-plot ``data`` and write it as a TensorBoard image summary.

    The figure is always closed afterwards, even if plotting or logging
    raises. This keeps it consistent with the wandb variant of this method,
    and ensures repeated calls do not accumulate open pyplot figures.

    Args:
        key: summary name for the image.
        data: values passed straight to ``plt.plot``.
        step: summary step to associate with the image.
    """
    fig = plt.figure()
    try:
        plt.plot(data)
        tf.summary.image(key, plot_to_image(fig), step=step)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        # Closing an already-closed figure is harmless.
        plt.close(fig)
def log_map(ds, key):
    """Plot ``key`` from ``ds`` with ``fv3viz.plot_cube`` and store it as a wandb image.

    The image is written into ``metrics[key]``. ``metrics`` is not defined
    in this function; it presumably comes from an enclosing or module scope
    -- confirm.

    Note: this closes *all* open pyplot figures, not only the one drawn here.
    """
    fv3viz.plot_cube(ds, key)
    drawn_fig = plt.gcf()
    metrics[key] = wandb.Image(plot_to_image(drawn_fig))
    plt.close("all")