import itertools
import numpy as np
from matplotlib.lines import Line2D

from base_plotting import PlotStructure
from wofs_ml.common.load_results import load_data

import wofs.plotting.plotting_config as plt_config
from wofs.plotting.Plot import Plotting
# Module-level Plotting instance.
# NOTE(review): `my_plt` is not referenced anywhere in this chunk -- verify it
# is used elsewhere in the file before removing it.
my_plt = Plotting()

import matplotlib.pyplot as plt

# Global matplotlib style for every figure made by this script:
# light-grey axes background and thin default line width.
plt.rc('axes', facecolor='#E6E6E6')
plt.rc('lines', linewidth=0.8)

# --- Experiment configuration ------------------------------------------------
ml_model_names = ['RandomForest', 'XGBoost', 'LogisticRegression']
time = 'second_hour'
target_set = ['tornado', 'severe_hail', 'severe_wind']
# Per-target preprocessing options (presumably None means the step is skipped;
# confirm against the training/loading code).
normalize_method = {'tornado': 'standard', 'severe_hail': None, 'severe_wind': None}
resample_method = {'tornado': 'under', 'severe_hail': None, 'severe_wind': None}
imputer_method = 'simple'

line_colors = plt_config.colors_for_ml_models_dict
line_labels = ml_model_names
# Plain string literal: the original was an f-string with no placeholders.
fig_fname = 'performance_diagram.png'

#################
base_plot = PlotStructure()
# 12 lead times: 60, 65, ..., 115 (presumably minutes -- TODO confirm units).
lead_times = np.arange(12) * 5 + 60
def plot_pred(date, time, fti, target, **kwargs):
    """Plot an ML prediction panel next to a baseline panel for one forecast time.

    Builds a two-panel figure. The first panel shows the LogisticRegression
    prediction and the second the baseline prediction for ``target``. Each
    panel overlays MRMS composite reflectivity, the WoFS probability-matched
    mean reflectivity and local storm reports. The figure is saved to a PNG
    whose name encodes date, time, target, time marker and ``fti``.

    Parameters
    ----------
    date
        Forecast date identifier (``str(date)`` is passed to the time helper).
    time : str
        Forecast initialization-time directory.
    fti : int
        Forecast time index. Indices ``<= 12`` map to ``'first_hour'``, larger
        ones to ``'second_hour'``. Formatted with ``:02d`` in the output name.
    target : str
        Hazard target, e.g. ``'tornado'``, ``'severe_hail'``, ``'severe_wind'``.
    **kwargs
        Must contain ``examples`` and ``info``. Both are forwarded to
        ``load_ml_predictions``.

    Returns
    -------
    None
        Returns early without saving anything when the MRMS or WoFS data
        cannot be loaded, or when the ML or baseline predictions are missing.

    Raises
    ------
    KeyError
        If ``examples`` or ``info`` is missing from ``kwargs``.
    """
    examples = kwargs['examples']
    info = kwargs['info']
    time_marker = 'first_hour' if fti <= 12 else 'second_hour'

    # Tornado gets a finer color scale capped at 0.5; other targets span 0-1.
    rng = np.arange(0., 0.55, 0.05) if 'torn' in target else np.arange(
        0., 1.1, 0.1)

    # Separate name so the caller's **kwargs is not overwritten.
    plot_kwargs = {'alpha': 0.95, 'extend': 'neither', 'cmap': target_cmap[target]}
    # Named `plotter` (not `plt`) to avoid shadowing matplotlib.pyplot.
    plotter = Plotting(date=None,
                       z1_levels=rng,
                       z2_levels=[35, 60],
                       z3_levels=[35., 60],
                       z4_levels=[prob_threshold[target], 90.],
                       z4_color='k',
                       **plot_kwargs)

    n_panels = 2
    fig, axes = plotter._create_fig(
        fig_num=0,
        sub_plots=(n_panels, 1),
        plot_map=False,
        figsize=(7, 6),
        hspace=0.15,
        wspace=0.15,
    )
    map_axes, x, y = plotter._generate_base_map(axes=axes, date=date, fig=fig)

    valid_date_and_time, initial_date_and_time = get_time.determine_forecast_valid_datetime(
        date_dir=str(date), time_dir=time, fcst_time_idx=fti)

    # Best-effort: skip this case when observations or forecast data cannot be
    # loaded. ``except Exception`` (not a bare ``except``) so KeyboardInterrupt
    # and SystemExit still propagate.
    try:
        mrms_dbz = mrms.load_single_mrms_time(
            date_dir=date,
            valid_datetime=valid_date_and_time,
            var_name='DZ_CRESSMAN')
    except Exception:
        return None

    try:
        ens_data = EnsembleData(date_dir=date,
                                time_dir=time,
                                base_path='summary_files')
        data = ens_data.load(variables=['comp_dz'],
                             time_indexs=[fti],
                             tag='ENS')
    except Exception:
        return None
    wofs_dbz = data['comp_dz'].values[0]

    # Collapse the ensemble to a probability-matched mean reflectivity field.
    wofs_dbz = prob_match_mean(var=wofs_dbz,
                               mean_var=np.mean(wofs_dbz, axis=0),
                               neighborhood=15)

    init_time_str, valid_time_str = get_timestampes(initial_date_and_time,
                                                    valid_date_and_time)

    load_lsr = loadLSR(date_dir=date,
                       date=valid_date_and_time[0],
                       time=valid_date_and_time[1],
                       forecast_length=30)
    # NOTE(review): warning polygons are loaded but not currently plotted
    # (wwa_points is not passed to spatial_plotting below).
    load_wwa = loadWWA(date_dir=date,
                       date=valid_date_and_time[0],
                       time=valid_date_and_time[1],
                       forecast_length=30,
                       time_window=30)

    hail_ll = load_lsr.load_hail_reports()
    torn_ll = load_lsr.load_tornado_reports()
    wind_ll = load_lsr.load_wind_reports()
    lsr_points = {'hail': hail_ll, 'tornado': torn_ll, 'wind': wind_ll}

    torn_wwa = load_wwa.load_tornado_warning_polygon(return_polygons=True)
    wwa_points = {'tornado': torn_wwa}

    ml_preds, objects = load_ml_predictions(time_marker,
                                            target,
                                            date,
                                            time,
                                            fti,
                                            model_names=['LogisticRegression'],
                                            examples=examples,
                                            info=info)

    baseline_pred = load_baseline(date, time, fti, target, time_marker)

    # Check for missing predictions BEFORE smoothing, so a None prediction is
    # never handed to smooth_objects.
    if baseline_pred is None or ml_preds is None:
        return None

    ml_preds, objects = smooth_objects(objects,
                                       ml_preds,
                                       model_names=['LogisticRegression'])

    if ml_preds is None:
        return None

    all_preds = ml_preds + [baseline_pred]
    titles = ['LogisticRegression', baseline_names[target]]

    for j, forecast_probabilities in enumerate(all_preds):
        ax = axes.flat[j]
        probs_pct = forecast_probabilities.astype(float) * 100.
        # Only draw the z4 threshold contour when some value exceeds the threshold.
        z4 = probs_pct if np.amax(probs_pct) > prob_threshold[target] else None

        # To overlay warning polygons, also pass wwa_points=wwa_points here.
        contours = plotter.spatial_plotting(
            fig,
            ax,
            x,
            y,
            lsr_points=lsr_points,
            # Mask zeros so regions with no probability are left unfilled.
            z1=np.ma.masked_where(forecast_probabilities == 0.,
                                  forecast_probabilities),
            z2=mrms_dbz,
            z3=wofs_dbz,
            z4=z4,
            map_ax=map_axes[j],
        )
        ax.set_title(titles[j], fontsize=12, pad=1.0, alpha=0.8)
        # NOTE(review): objects[0] is used for every panel, including the
        # baseline -- confirm this is intended.
        label_centroid(objects[0], forecast_probabilities, ax)

    axes.flat[0].text(
        0.0,
        1.17,
        f'Init Time : {init_time_str}',
        fontsize=8,
        alpha=0.9,
        transform=axes.flat[0].transAxes,
    )

    axes.flat[0].text(0.0,
                      1.12,
                      f'Valid Time : {valid_time_str}',
                      fontsize=8,
                      alpha=0.9,
                      transform=axes.flat[0].transAxes)

    base_plot.add_alphabet_label(n_panels, axes, pos=(0.9, 0.08))

    # Legend entries; only used by the (currently disabled) set_legend call below.
    additional_handles = [
        Line2D([0], [0],
               marker='o',
               color='w',
               markerfacecolor='r',
               markersize=6,
               alpha=0.8),
        Line2D([0], [0],
               marker='o',
               color='w',
               markerfacecolor='b',
               markersize=6,
               alpha=0.8),
        Line2D([0], [0],
               marker='o',
               color='w',
               markerfacecolor='g',
               markersize=6,
               alpha=0.8),
        Line2D([0], [0], color='b', alpha=0.8),
        Line2D([0], [0], color='k', alpha=0.8),
    ]

    additional_labels = [
        'Tornado', 'Severe Wind', 'Severe Hail', 'WoFS PMM DBZ > 35',
        'MRMS DBZ > 35'
    ]

    major_ax = base_plot.set_major_axis_labels(fig,
                                               xlabel='',
                                               ylabel_left='',
                                               labelpad=5)

    # Legend disabled; re-enable to show report/contour legend entries:
    #base_plot.set_legend(n_panels, fig, axes[0,0], major_ax, additional_handles, additional_labels, bbox_to_anchor=(0.425, -0.375), ncol=2)

    colorbar_labels = {
        'tornado': 'Probability of Tornado',
        'severe_hail': 'Probability of Severe Hail',
        'severe_wind': 'Probability of Severe Wind'
    }
    # `contours` comes from the last panel drawn; all_preds always holds at
    # least the baseline prediction, so it is always bound here.
    base_plot.add_colorbar(fig=fig,
                           plot_obj=contours,
                           ax=axes,
                           colorbar_label=colorbar_labels[target])

    fname = f'example_ml_predictions_vs_baseline_{date}_{time}_{target}_{time_marker}_{fti:02d}.png'
    plotter._save_fig(fig=fig, fname=fname, dpi=300, aformat='png')
Example #3
0
                   date=valid_date_and_time[0],
                   time=valid_date_and_time[1],
                   forecast_length=30)
hail_ll = load_lsr.load_hail_reports()
torn_ll = load_lsr.load_tornado_reports()
wind_ll = load_lsr.load_wind_reports()
hail_xy, torn_xy, wind_xy = load_reports(str(date),
                                         valid_date_and_time,
                                         time_window=15,
                                         forecast_length=30)
lsr_points = {'hail': hail_ll, 'tornado': torn_ll, 'wind': wind_ll}

kwargs = {'alpha': 0.7, 'extend': 'neither', 'cmap': 'wofs'}
plt = Plotting(date=date,
               z1_levels=np.arange(0., 1.1, .1),
               z2_levels=[0.],
               z3_levels=[0.],
               **kwargs)

object_props = regionprops(probability_objects.astype(int),
                           prob_of_storm_location,
                           coordinates='rc')
verification_dict = {
    'severe_hail': hail_xy,
    'severe_wind': wind_xy,
    'tornado': torn_xy
}
matched_at_15km = {
    'matched_to_{}_15km'.format(atype):
    match_to_lsrs(object_properties=object_props,
                  lsr_points=verification_dict[atype],
    valid_dt = to_dt(valid_date_and_time)

    init_time_str = initial_dt.strftime("%Y-%m-%d, %H:%M UTC")
    valid_time_str = valid_dt.strftime("%Y-%m-%d, %H:%M")

    duration = valid_dt + timedelta(minutes=30)

    duration_time = duration.strftime("%H:%M")

    valid_time_str += '-' + duration_time + ' UTC'

    return init_time_str, valid_time_str


# Contour styling shared by every panel of the multi-target figure.
kwargs = dict(alpha=0.7, extend='neither', cmap='wofs')
# NOTE(review): this rebinds the module-level name `plt` from matplotlib.pyplot
# to a Plotting instance; any later pyplot call made through `plt` will no
# longer reach pyplot. Renaming it requires updating the code that follows.
plt = Plotting(
    date=None,
    z1_levels=np.arange(0., 1.1, .1),
    z2_levels=[0.],
    z3_levels=[0.],
    **kwargs,
)

# One row per hazard target in a single column.
nrows, ncols = len(target_set), 1
fig, axes = plt._create_fig(
    fig_num=0,
    sub_plots=(nrows, ncols),
    plot_map=False,
    figsize=(7, 6),
    hspace=0.2,
    wspace=0.15,
)
# State for the per-target loop that follows (presumably panel index and the
# collected contour objects -- confirm against the loop body).
i = 0
contours_set = []
for target_var in target_set: 
    normalize = normalize_method[target_var]