コード例 #1
0
    def sink(
        self, time: cftime.DatetimeJulian, data: Mapping[str, xr.DataArray]
    ):
        """Write one image summary per entry of ``data`` and advance the step.

        Each array is drawn on its own figure via its ``plot`` method and
        passed to ``tf.summary.image`` tagged with the variable name at the
        current ``self.step``. ``self.step`` is incremented once per call.

        Args:
            time: not used by this method.
            data: mapping from variable name to the array to plot.
        """
        for name, array in data.items():
            figure = plt.figure()
            array.plot()
            image = plot_to_image(figure)
            tf.summary.image(f"{name}", image, step=self.step)
            plt.close(figure)

        self.step += 1
コード例 #2
0
def plot_time_heights(
    prognostic: xr.Dataset, baseline: xr.Dataset, do_variables=COMPARE_VARS
):
    """Log one prognostic-vs-baseline panel figure to wandb per variable.

    Both datasets are first passed through ``consistent_time_len``. For every
    name in ``do_variables`` the figure returned by
    ``plot_global_avg_by_height_panel`` is logged under
    ``avg_time_height/<name>`` and then closed.

    Args:
        prognostic: dataset plotted as the first panel input.
        baseline: dataset plotted as the second panel input.
        do_variables: names of the variables to plot.
    """
    prognostic, baseline = consistent_time_len(prognostic, baseline)

    for var_name in do_variables:
        prog_da = prognostic[var_name]
        base_da = baseline[var_name]
        panel = plot_global_avg_by_height_panel(prog_da, base_da)
        image = wandb.Image(plot_to_image(panel))
        wandb.log({f"avg_time_height/{var_name}": image})
        plt.close(panel)
コード例 #3
0
def plot_lat_heights(
    prognostic: xr.Dataset, baseline: xr.Dataset, do_variables=COMPARE_VARS
):
    """Log zonal-average comparison figures for the end of the run to wandb.

    After ``consistent_time_len`` is applied, both datasets are averaged over
    their last (at most) 8 time indices. For each name in ``do_variables``
    the two time means are reduced with ``zonal_average_approximate`` using
    the baseline's ``lat`` variable, plotted against ``lat``, logged under
    ``zonal_avg/<name>`` and the figure closed.

    Args:
        prognostic: dataset plotted as the first panel input.
        baseline: dataset plotted as the second panel input; also the source
            of the ``lat`` values.
        do_variables: names of the variables to plot.
    """
    prognostic, baseline = consistent_time_len(prognostic, baseline)

    # Window covering the final 8 time indices (or all of them if fewer).
    n_times = len(prognostic.time)
    tail = slice(max(n_times - 8, 0), n_times)

    prog_tail_mean = prognostic.isel(time=tail).mean(dim="time")
    base_tail_mean = baseline.isel(time=tail).mean(dim="time")
    latitude = base_tail_mean["lat"]

    for var_name in do_variables:
        prog_by_lat = zonal_average_approximate(latitude, prog_tail_mean[var_name])
        base_by_lat = zonal_average_approximate(latitude, base_tail_mean[var_name])

        panel = plot_global_avg_by_height_panel(prog_by_lat, base_by_lat, x="lat")
        image = wandb.Image(plot_to_image(panel))
        wandb.log({f"zonal_avg/{var_name}": image})
        plt.close(panel)
コード例 #4
0
def plot_global_means(
    prognostic: xr.Dataset, baseline: xr.Dataset, do_variables=COMPARE_VARS
):
    """Log an overlay plot of prognostic and baseline for each variable.

    Both datasets are first passed through ``consistent_time_len``. A
    variable missing from either dataset is skipped with an info-level log
    message. Otherwise both arrays are drawn on one axes (labelled
    "Emulation" and "Baseline") and the figure is logged to wandb under
    ``global_avg/<name>`` and closed.

    Args:
        prognostic: dataset drawn with the "Emulation" label.
        baseline: dataset drawn with the "Baseline" label.
        do_variables: names of the variables to plot.
    """
    prognostic, baseline = consistent_time_len(prognostic, baseline)

    for var_name in do_variables:
        in_both = var_name in baseline and var_name in prognostic
        if not in_both:
            logger.info(f"Skipping global mean due to missing variable: {var_name}")
            continue

        figure, axes = plt.subplots()
        figure.set_dpi(80)

        prognostic[var_name].plot(ax=axes, label="Emulation")
        baseline[var_name].plot(ax=axes, label="Baseline", alpha=0.6)
        plt.legend()

        image = wandb.Image(plot_to_image(figure))
        wandb.log({f"global_avg/{var_name}": image})
        plt.close(figure)
コード例 #5
0
def plot_transects(
    prognostic: xr.Dataset, baseline: xr.Dataset, do_variables=TRANSECT_VARS
):
    """Log side-by-side meridional transect figures at two time indices.

    For time index 0 ("start") and ``len(prognostic.time) - 2`` ("near_end"),
    and for each name in ``do_variables``, ``plot_meridional`` is called for
    the prognostic data on the left axes and the baseline data on the right
    axes. Each figure is logged to wandb under
    ``meridional_transect/<time_name>/<name>`` and closed.

    Args:
        prognostic: dataset drawn on the "Emulation" panel; its time length
            determines the "near_end" index used for both datasets.
        baseline: dataset drawn on the "Baseline" panel.
        do_variables: names of the variables to plot.
    """
    time_indices = {"start": 0, "near_end": len(prognostic.time) - 2}

    for time_label, t_index in time_indices.items():
        for var_name in do_variables:
            figure, (emulation_ax, baseline_ax) = plt.subplots(1, 2)
            figure.set_size_inches(8, 4)
            figure.set_dpi(80)

            plot_meridional(prognostic.isel(time=t_index), var_name, ax=emulation_ax)
            plot_meridional(baseline.isel(time=t_index), var_name, ax=baseline_ax)

            emulation_ax.set_title("Emulation")
            baseline_ax.set_title("Baseline")
            plt.tight_layout()

            image = wandb.Image(plot_to_image(figure))
            wandb.log({f"meridional_transect/{time_label}/{var_name}": image})
            plt.close(figure)
コード例 #6
0
def plot_spatial_comparisons(
    prognostic: xr.Dataset,
    baseline: xr.Dataset,
    do_variables=COMPARE_VARS,
    time_idxs: Mapping[str, int] = None,
    level_map: Mapping[str, int] = None,
):
    """Log spatial comparison figures for each variable, level and time.

    For every combination of variable, level and time index, both datasets
    are indexed with ``isel(time=..., z=...)`` and handed to
    ``plot_spatial_2panel_with_diff``. The resulting figure is logged to
    wandb under ``spatial_comparison/<time>/<level>/<name>`` and closed.

    Args:
        prognostic: first dataset passed to the plotting helper.
        baseline: second dataset passed to the plotting helper.
        do_variables: names of the variables to plot.
        time_idxs: label -> integer index along ``time``. Defaults to
            ``{"start": 0, "near_end": len(prognostic.time) - 2}``.
        level_map: label -> integer index along ``z``. Defaults to
            ``{"lower": 75, "upper_BL": 60, "upper_atm": 20}``.
    """
    if level_map is None:
        level_map = {"lower": 75, "upper_BL": 60, "upper_atm": 20}

    if time_idxs is None:
        time_idxs = {"start": 0, "near_end": len(prognostic.time) - 2}

    for var_name in do_variables:
        for level_label, z_index in level_map.items():
            for time_label, t_index in time_idxs.items():
                indexer = dict(time=t_index, z=z_index)
                prog_slice = prognostic.isel(**indexer)
                base_slice = baseline.isel(**indexer)
                figure = plot_spatial_2panel_with_diff(
                    prog_slice, base_slice, var_name
                )

                image = wandb.Image(plot_to_image(figure))
                wandb.log(
                    {f"spatial_comparison/{time_label}/{level_label}/{var_name}": image}
                )
                plt.close(figure)
コード例 #7
0
 def log_profiles(self, key, data, step):
     """Plot ``data`` on a new figure and log it to wandb.

     Args:
         key: name under which the image is logged.
         data: values handed to ``plt.plot``.
         step: step value forwarded to ``wandb.log``.
     """
     figure = plt.figure()
     plt.plot(data)
     image = wandb.Image(plot_to_image(figure))
     wandb.log({key: image}, step=step)
     plt.close(figure)
コード例 #8
0
 def log_profiles(self, key, data, step):
     """Plot ``data`` on a new figure and write it as an image summary.

     The figure is always closed before returning so that repeated calls do
     not accumulate open matplotlib figures.

     Args:
         key: name passed to ``tf.summary.image``.
         data: values handed to ``plt.plot``.
         step: step value passed to ``tf.summary.image``.
     """
     fig = plt.figure()
     try:
         plt.plot(data)
         tf.summary.image(key, plot_to_image(fig), step)
     finally:
         # pyplot keeps a reference to every figure it creates until it is
         # closed explicitly. Closing here is harmless even if
         # plot_to_image already closed the figure (not visible from here).
         plt.close(fig)
コード例 #9
0
def log_map(ds, key):
    """Draw ``key`` from ``ds`` with ``fv3viz.plot_cube`` and store the image.

    The current pyplot figure is converted with ``plot_to_image``, wrapped
    in a ``wandb.Image`` and stored in ``metrics`` (a name from the enclosing
    scope) under ``key``. All open pyplot figures are then closed.

    Args:
        ds: data passed as the first argument to ``fv3viz.plot_cube``.
        key: variable name to plot; also the key used in ``metrics``.
    """
    fv3viz.plot_cube(ds, key)
    current_figure = plt.gcf()
    image = wandb.Image(plot_to_image(current_figure))
    metrics[key] = image
    plt.close("all")