Exemplo n.º 1
0
def plot_rmse(*args, ref=None, area=None, title_font_size=0.4, y_max=None):
    """
    Plot RMSE curves of the input forecasts against a reference fieldset.

    Every layer built from ``args`` whose data is an ``mv.Fieldset`` yields
    one or more curves:

    * when its "number" metadata has more than one unique value the layer is
      handled as an ensemble: one curve per member (tagged "cf" when the
      number is "0", "pf" otherwise) plus one curve for the ensemble mean.
      Only one such layer is accepted;
    * otherwise the layer is handled as a single forecast ("fc").

    Each curve is ``mv.sqrt(mv.average((forecast - ref) ** 2))`` computed on
    the pair of fieldsets returned by ``_prepare_grid``. The curves are drawn
    against ``ref.valid_date()`` together with a title and a legend box, and
    the result is passed to ``mv.plot``.

    Parameters:
        *args: forwarded to ``_make_layers`` to build the input layers.
        ref: the reference ``mv.Fieldset``; mandatory.
        area: accepted but not referenced anywhere in the function body.
        title_font_size: font size handed to ``Title``.
        y_max: upper limit of the y axis. When None it is derived from the
            computed RMSE values via ``_y_max`` and
            ``Layout.compute_axis_range``.

    Raises:
        Exception: if ``ref`` is not an ``mv.Fieldset``, or if more than one
            ensemble fieldset is found among the layers.
    """
    if not isinstance(ref, mv.Fieldset):
        raise Exception(f"Missing or invalid ref argument!")

    plot_items = []

    # curve_info[k] = (kind, source fieldset) describes curve_values[k]
    curve_info = []
    curve_values = []
    titled_fieldsets = []
    ens_seen = False

    # compute the rmse for each input layer
    for layer in _make_layers(*args, form_layout=False):
        fs = layer["data"]
        if not isinstance(fs, mv.Fieldset):
            continue

        numbers = fs._unique_metadata("number")

        if len(numbers) > 1:
            # ensemble forecast
            if ens_seen:
                raise Exception(
                    "Only one ENS fieldset can be used in plot_rmse()!")
            ens_seen = True

            member_sum = None  # running sum used for the ensemble mean
            for number in numbers:
                ref_grid, member = _prepare_grid(ref, fs.select(number=number))
                curve_info.append(("cf" if number == "0" else "pf", fs))
                curve_values.append(
                    mv.sqrt(mv.average((member - ref_grid)**2)))
                if member_sum is None:
                    member_sum = member
                else:
                    member_sum = member_sum + member

            # the ensemble mean is compared with the reference grid produced
            # for the last member
            curve_info.append(("em", fs))
            curve_values.append(
                mv.sqrt(mv.average((member_sum / len(numbers) - ref_grid)**2)))
        else:
            # deterministic forecast
            ref_grid, forecast = _prepare_grid(ref, fs)
            curve_info.append(("fc", fs))
            curve_values.append(mv.sqrt(mv.average((forecast - ref_grid)**2)))

        titled_fieldsets.append(fs)

    # x axis: spans the valid dates of the reference
    dates = ref.valid_date()
    x_min, x_max = dates[0], dates[-1]
    x_tick = 1
    x_title = ""

    # y axis: starts at zero; the upper limit is either given or computed
    y_min = 0
    if y_max is not None:
        y_tick, _, _ = Layout.compute_axis_range(0, y_max)
    else:
        y_tick, _, y_max = Layout.compute_axis_range(0, _y_max(curve_values))
    units = mv.grib_get_string(ref[0], "units")
    y_title = "RMSE [" + units + "]"

    # the view
    plot_items.append(Layout().build_rmse(x_min, x_max, y_min, y_max, x_tick,
                                          y_tick, x_title, y_title))

    # curve styling
    ef_label = {"cf": "ENS cf", "pf": "ENS pf", "em": "ENS mean"}
    ef_colour = {"cf": "black", "pf": "red", "em": "kelly_green"}
    fc_colours = [
        "red", "blue", "green", "black", "cyan", "evergreen", "gold", "pink"
    ]
    if ens_seen:
        # keep the ensemble colours out of the deterministic palette
        reserved = list(ef_colour.values())
        fc_colours = [c for c in fc_colours if c not in reserved]

    pf_in_legend = False
    colour_pos = -1
    legend_entries = 0

    for (kind, fs), values in zip(curve_info, curve_values):
        vis = mv.input_visualiser(input_x_type="date",
                                  input_date_x_values=dates,
                                  input_y_values=values)

        graph = {"graph_type": "curve"}
        colour = "black"
        thickness = 1
        with_legend = False
        legend_text = None

        if kind == "fc":
            thickness = 3
            colour_pos = (colour_pos + 1) % len(fc_colours)
            colour = fc_colours[colour_pos]
            legend_text = fs.label
            with_legend = True
        elif kind == "pf":
            colour = ef_colour["pf"]
            # all the perturbed members share a single legend entry
            if not pf_in_legend:
                pf_in_legend = True
                legend_text = ef_label["pf"]
                with_legend = True
        elif kind in ("cf", "em"):
            thickness = 3
            colour = ef_colour[kind]
            legend_text = ef_label[kind]
            with_legend = True

        if with_legend:
            graph["legend_user_text"] = legend_text
            graph["legend"] = "on"
            legend_entries += 1

        graph["graph_line_colour"] = colour
        graph["graph_line_thickness"] = thickness

        plot_items.append(vis)
        plot_items.append(mv.mgraph(**graph))

    # title
    title_item = Title(font_size=title_font_size).build_rmse(
        ref, titled_fieldsets)
    if title_item is not None:
        plot_items.append(title_item)

    # legend box: its height grows with the number of entries
    entry_height = 0.35 + 0.5
    gap = 0.1
    leg_left = 3.5
    leg_height = legend_entries * entry_height + (legend_entries + 1) * gap
    leg_bottom = 17.5 - leg_height

    plot_items.append(
        mv.mlegend(
            legend_display_type="disjoint",
            legend_entry_plot_direction="column",
            legend_text_composition="user_text_only",
            legend_border="on",
            legend_border_colour="black",
            legend_box_mode="positional",
            legend_box_x_position=leg_left,
            legend_box_y_position=leg_bottom,
            legend_box_x_length=4,
            legend_box_y_length=leg_height,
            legend_text_font_size=0.35,
            legend_box_blanking="on",
        ))

    mv.plot(plot_items, animate=False)
Exemplo n.º 2
0
def test_sqrt_geopoints():
    """Check that mv.sqrt() on TEST_GEOPOINTS keeps the geopoints type and
    that the maximum of the result is close to MAX_SQRT_GPT."""
    result = mv.sqrt(TEST_GEOPOINTS)
    largest = mv.maxvalue(result)
    result_type = mv.type(result)
    assert result_type == 'geopoints'
    assert np.isclose(largest, MAX_SQRT_GPT)
Exemplo n.º 3
0
UC-05. The Analyst retrieves certain parameters, computes a derived
parameter and plots its values for a specific time.

--------------------------------------------------------------------------------
1. Analyst filters the chosen parameters of a file from a given
   source (e.g. MARS, files, observation databases)
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
2. Analyst computes the desired derived parameter (e.g. wind speed from zonal
   and meridional components)
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
3. Analyst plots the derived parameter values
--------------------------------------------------------------------------------
"""

import metview as mv

# Step 1: read the GRIB file, then filter out the 'u' and 'v' parameters
uv = mv.read('./uv.grib')
u = mv.read(data=uv, param='u')
v = mv.read(data=uv, param='v')

# Step 2: derive the wind speed from the two components
wind_speed = mv.sqrt(u * u + v * v)

# Step 3: send the plot to a PNG output instead of the default one
png = mv.png_output(output_width=1000, output_name='./uvplot')
mv.setoutput(png)
mv.plot(wind_speed)
Exemplo n.º 4
0
def test_sqrt():
    """Check that the maximum of mv.sqrt(TEST_FIELDSET) is close to the
    square root of MAX_VALUE."""
    observed = mv.maxvalue(mv.sqrt(TEST_FIELDSET))
    expected = np.sqrt(MAX_VALUE)
    assert np.isclose(observed, expected)