Esempio n. 1
0
def get_predictions_eval_test():
    """Join ds, osa and ebma eval-test predictions onto the pgm actuals.

    Reads ``landed.{ds,osa,ebma}_pgm_eval_test`` indexed by (month_id, pg_id),
    checks that all three share the same first and last month_id, then loads
    the outcome columns of ``preflight.flight_pgm`` for that month range and
    merges (pandas default: inner join on the index) each prediction frame
    onto it.

    Uses the module-level ``connectstring``.

    Returns:
        pandas.DataFrame of actuals plus all prediction columns, presumably
        indexed by (month_id, pg_id) as produced by ``db_to_df_limited``.

    Raises:
        RuntimeError: if the first or last month_id differs between the
            three prediction tables.
    """
    schema_pred = "landed"
    schema_actuals = "preflight"

    table_ds = "ds_pgm_eval_test"
    table_osa = "osa_pgm_eval_test"
    table_ebma = "ebma_pgm_eval_test"

    table_actuals = "flight_pgm"

    timevar = "month_id"
    groupvar = "pg_id"
    ids = [timevar, groupvar]

    outcomes = [
        "ged_dummy_sb", "ged_dummy_ns", "ged_dummy_os", "acled_dummy_pr"
    ]

    df_ds, df_osa, df_ebma = [
        dbutils.db_to_df(connectstring, schema_pred, table).set_index(ids)
        for table in (table_ds, table_osa, table_ebma)
    ]

    def _time_bounds(df_pred):
        """Return (min, max) of the time level of df_pred's index."""
        times = df_pred.index.get_level_values(timevar)
        return times.min(), times.max()

    t_start_ds, t_end_ds = _time_bounds(df_ds)
    t_start_osa, t_end_osa = _time_bounds(df_osa)
    t_start_ebma, t_end_ebma = _time_bounds(df_ebma)

    start_same = t_start_ds == t_start_osa == t_start_ebma
    end_same = t_end_ds == t_end_osa == t_end_ebma

    # Negate the conjunction as a whole: `not start_same and end_same` parses
    # as `(not start_same) and end_same`, which let mismatched end months
    # (and the case where both start and end differ) pass silently.
    if not (start_same and end_same):
        raise RuntimeError("The time indexes for ds, osa and ebma don't match")

    df = dbutils.db_to_df_limited(connectstring,
                                  schema_actuals,
                                  table_actuals,
                                  columns=outcomes + ids,
                                  timevar=timevar,
                                  groupvar=groupvar,
                                  tmin=t_start_ds,
                                  tmax=t_end_ds)

    for df_pred in (df_ds, df_osa, df_ebma):
        df = df.merge(df_pred, left_index=True, right_index=True)
    return df
Esempio n. 2
0
def fetch_df_country_names(connectstring):
    """Return the id and name columns of ``staging.country``.

    ``id`` is passed as the id column; callers look rows up by country id
    via ``.loc[country_id]``.
    """
    return dbutils.db_to_df(
        connectstring,
        "staging",
        "country",
        columns=["id", "name"],
        ids=["id"],
    )
Esempio n. 3
0
def fetch_df_geo_pgm(connectstring):
    """Load priogrid geometry (row/col, lat/lon, in_africa) indexed by pg_id.

    The ``gid`` column of ``staging.priogrid`` is renamed to ``pg_id`` and
    used as the index.
    """
    wanted = ["gid", "row", "col", "latitude", "longitude", "in_africa"]
    df_geo = dbutils.db_to_df(connectstring, "staging", "priogrid", wanted)
    # Move gid to a new trailing pg_id column, then promote it to the index.
    df_geo['pg_id'] = df_geo.pop('gid')
    return df_geo.set_index(['pg_id'])
Esempio n. 4
0
def fetch_df_months(connectstring):
    """Return a one-column frame (``datestr``) indexed by month_id.

    Reads id, month and year_id from ``staging.month``; ``datestr`` is
    ``str(year_id) + "-" + str(month)``.
    """
    months = dbutils.db_to_df(connectstring, "staging", "month",
                              columns=["id", "month", "year_id"])
    year_str = months['year_id'].map(str)
    month_str = months['month'].map(str)
    months['datestr'] = year_str + "-" + month_str

    months = months.rename(columns={'id': 'month_id'}).set_index(['month_id'])
    return months[['datestr']]
Esempio n. 5
0
def plot_lines_per_group_with_actuals(df, connectstring, directory):
    """Write one line plot per (group, numeric column) with actuals overlaid.

    Actuals come from ``landed.agg_cm_actuals`` indexed by
    (month_id, country_id), restricted by ``subset_times``, and outer-merged
    onto ``df``. For each group (index level 1) and each numeric column ``c``
    of the original ``df``, the plot shows ``"pgm_" +
    maputils.match_plotvar_actual(c)`` as "History" and ``c`` dashed.

    Args:
        df: predictions; level 1 of its index holds the group (country) ids.
        connectstring: database connection string passed to dbutils.
        directory: output root; plots go to
            ``<directory>/country_avg_w_actuals/<group>_<column>.png``.
    """
    directory = directory + "/country_avg_w_actuals/"
    create_dirs([directory])

    groups = get_groups(df)
    # Columns are chosen before the merge so the actuals are not plotted as
    # if they were predictions.
    columns = get_numeric_cols(df)

    df_actuals = dbutils.db_to_df(connectstring, "landed", "agg_cm_actuals")
    df_actuals.set_index(['month_id', 'country_id'], inplace=True)
    df_actuals = subset_times(df_actuals, *get_times_actuals(df))

    df_names = fetch_df_country_names(connectstring)
    df_months = fetch_df_months(connectstring)

    df = df.merge(df_actuals, left_index=True, right_index=True, how='outer')

    # The matching actuals column depends only on the prediction column.
    actual_of = {c: "pgm_" + maputils.match_plotvar_actual(c) for c in columns}

    for g in groups:
        dg = df.xs(g, level=1)
        country_name = df_names.loc[g, 'name']

        for c in columns:
            actual = actual_of[c]
            path = directory + str(int(g)) + "_" + str(c) + ".png"
            plt.figure()

            # If both series stay below 1%, pin the y-range to [~0, 0.01];
            # otherwise matplotlib picks its own limits.
            if dg[c].max() < 0.01 and dg[actual].max() < 0.01:
                plt.ylim([-0.0001, 0.01])

            plt.plot(dg[actual], label='History')
            plt.plot(dg[c], linestyle='--', label=c)

            plt.xticks(*make_ticks(df, df_months), rotation=90)

            # Title is the country name only; the column is in the legend
            # and in the file name.
            plt.title(country_name, loc='left')
            plt.legend()
            plt.tight_layout()
            plt.savefig(path)
            print("wrote", path)
            plt.close()
Esempio n. 6
0
def get_data(connectstring, level, runtype):
    """Load calibrated test-period predictions for one level and run type.

    Reads ``landed.calibrated_<level>_<runtype>_test`` with ids
    (month_id, <groupvar>) and returns it sorted by index.

    Args:
        connectstring: database connection string passed to dbutils.
        level: "cm" (grouped by country_id) or "pgm" (grouped by pg_id).
        runtype: run type used in the table name, e.g. "fcast" or "eval".

    Raises:
        ValueError: if ``level`` is not "cm" or "pgm" (previously this
            surfaced as an UnboundLocalError on ``groupvar``).
    """
    groupvars = {"cm": "country_id", "pgm": "pg_id"}
    if level not in groupvars:
        raise ValueError(
            "level must be 'cm' or 'pgm', got {!r}".format(level))

    schema = "landed"
    timevar = "month_id"
    ids = [timevar, groupvars[level]]

    table = "_".join(["calibrated", level, runtype, "test"])
    df = dbutils.db_to_df(connectstring, schema, table, ids=ids)
    df.sort_index(inplace=True)

    return df
Esempio n. 7
0
run_id = args.run_id

timevar = "month_id"
groupvar = "pg_id"
# One GED dummy outcome per violence type.
outcomes = ["ged_dummy_{}".format(kind) for kind in ("sb", "ns", "os")]
print("outcomes:", outcomes)

connectstring = dbutils.make_connectstring(prefix="postgresql",
                                           db="views",
                                           uname="VIEWSADMIN",
                                           hostname="VIEWSHOST",
                                           port="5432")

df_pgm = dbutils.db_to_df(connectstring,
                          schema="landed",
                          table="ensemble_pgm_fcast_test",
                          ids=[timevar, groupvar])
df_c = dbutils.db_to_df(connectstring,
                        schema="staging",
                        table="country",
                        columns=["id", "name"]).rename(
                            columns={'id': 'country_id'})

df_cpgm = dbutils.db_to_df(connectstring,
                           schema="staging_test",
                           table="cpgm",
                           ids=[timevar, groupvar])

# Attach country_id to each (month, cell), flatten the index, then add names.
df = (
    df_pgm.merge(df_cpgm, left_index=True, right_index=True)
    .reset_index()
    .merge(df_c, on=["country_id"])
)
Esempio n. 8
0
timevar = "month_id"
groupvar = "pg_id"
schema_input = "launched"
table_input = "transforms_pgm_imp_1"
schema_output = "landed"
table_output = "rescaled_pgm"

ids = [timevar, groupvar]
vars_to_rescale = [var['name'] for var in rescales]
vars_rescaled = []
cols = ids + vars_to_rescale

df = dbutils.db_to_df(connectstring,
                      schema_input,
                      table_input,
                      columns=cols,
                      ids=ids)

for rescale in rescales:
    name_new = rescale['name'] + "_rescaled"
    vars_rescaled.append(name_new)
    # Build a fresh kwargs dict rather than updating rescale['opts'] in place:
    # the in-place update mutated the shared `rescales` config and left the
    # whole input column stored inside it.
    scale_kwargs = dict(rescale['opts'], x=df[rescale['name']])

    df[name_new] = scale_to_range(**scale_kwargs)

df = df[vars_rescaled]

dbutils.df_to_db(connectstring,
                 df,
                 schema_output,
Esempio n. 9
0
tables_cl = ["cl_pgm_eval_calib",
             "cl_pgm_eval_test",
             "cl_pgm_fcast_calib",
             "cl_pgm_fcast_test"]

tables_calibrated = ["calibrated_cm_eval_test",
                     "calibrated_cm_fcast_test",
                     "calibrated_pgm_eval_test",
                     "calibrated_pgm_fcast_test",]


# Dump each table of every (output dir, tables) group to <dir>/<table>.hdf5.
for _dir_out, _group_tables in ((dir_results_osa, tables_osa),
                                (dir_results_ensemble, tables_ensemble),
                                (dir_results_agg, tables_agg)):
    for table in _group_tables:
        path = "/".join([_dir_out, table]) + ".hdf5"
        df = dbutils.db_to_df(connectstring, schema, table)
        df.to_hdf(path, key='data', complevel=9)
        print("wrote", path)
Esempio n. 10
0
def _logo_layout(df_geo):
    """Compute the text-box anchor and the EU/ERC/VIEWS logo extents.

    Positions are offsets from the minimum longitude/latitude of ``df_geo``,
    in the same units as its ``longitude``/``latitude`` columns.

    Returns:
        (textbox_xy, box_logo_eu, box_logo_erc, box_logo_views), where each
        box is an (x0, x1, y0, y1) extent for ``plt.imshow``.
    """
    lon_min = df_geo['longitude'].min()
    lat_min = df_geo['latitude'].min()

    w_eu = 6.705
    h_eu = 4.5
    w_erc = 4.7109375
    h_erc = 4.5
    h_views = 4.024

    # EU and ERC logos sit side by side near the lower-left corner.
    lon_start_eu = lon_min + 1
    lon_end_eu = lon_start_eu + w_eu
    lon_start_erc = lon_end_eu + 1
    lon_end_erc = lon_start_erc + w_erc

    lat_start_eu = lat_min + 1
    lat_end_eu = lat_start_eu + h_eu
    lat_start_erc = lat_start_eu
    lat_end_erc = lat_start_erc + h_erc

    # The VIEWS logo spans from the EU logo's left edge to the ERC logo's
    # right edge, one unit above the EU logo; the text box goes above it.
    lat_start_views = lat_end_eu + 1
    lat_end_views = lat_start_views + h_views

    box_logo_eu = (lon_start_eu, lon_end_eu, lat_start_eu, lat_end_eu)
    box_logo_erc = (lon_start_erc, lon_end_erc, lat_start_erc, lat_end_erc)
    box_logo_views = (lon_start_eu, lon_end_erc,
                      lat_start_views, lat_end_views)
    textbox_xy = (lon_start_eu, lat_end_views + 1)
    return textbox_xy, box_logo_eu, box_logo_erc, box_logo_views


def plot_map_worker(local_settings, plotjob):
    """Render one PNG map per time step for a single plot job.

    Args:
        local_settings: dict providing 'connectstring', 'dir_plots',
            'dir_spatial_pgm' and 'dir_spatial_cm'.
        plotjob: dict providing 'plotvar', 'varname_actual' (None for no
            actuals), 'schema_plotvar', 'schema_actual', 'table_plotvar',
            'table_actual', 'timevar', 'groupvar', 'variable_scale',
            'projection', 'crop' and 'run_id'.

    Writes ``<dir_plots><schema>/<table>/<plotvar>/<t>.png`` for every t in
    ``range(time_start, time_end + 1)`` from ``get_time_limits``.

    Raises:
        ValueError: if plotjob['groupvar'] is neither "pg_id" nor
            "country_id" (previously a NameError deep inside the loop).
    """
    connectstring = local_settings['connectstring']
    dir_plots = local_settings['dir_plots']
    dir_spatial_pgm = local_settings['dir_spatial_pgm']
    dir_spatial_cm = local_settings['dir_spatial_cm']

    plotvar = plotjob['plotvar']
    varname_actual = plotjob['varname_actual']
    schema_plotvar = plotjob['schema_plotvar']
    schema_actual = plotjob['schema_actual']
    table_plotvar = plotjob['table_plotvar']
    table_actual = plotjob['table_actual']
    timevar = plotjob['timevar']
    groupvar = plotjob['groupvar']
    variable_scale = plotjob['variable_scale']
    projection = plotjob['projection']
    crop = plotjob['crop']
    run_id = plotjob['run_id']

    # Geometry, figure size and shapefiles below only exist for these two.
    if groupvar not in ("pg_id", "country_id"):
        raise ValueError("Unsupported groupvar {!r}; expected 'pg_id' or "
                         "'country_id'".format(groupvar))

    path_shape_pg = dir_spatial_pgm + "/priogrid"
    path_shape_c = dir_spatial_cm + "/country"

    ids = [timevar, groupvar]

    print("Plotting variable {} from table {}.{}".format(plotvar,
        schema_plotvar, table_plotvar))

    print(json.dumps(plotjob, indent=4))

    df_plotvar = dbutils.db_to_df(connectstring, schema_plotvar, table_plotvar,
        [plotvar], ids)

    df_plotvar = restrict_prob_lower_bound(df_plotvar, 0.001)

    time_start, time_end = get_time_limits(df_plotvar, timevar)
    print("time_start:", time_start)
    print("time_end:", time_end)

    have_actual = varname_actual is not None

    # Event coordinates are only used for pg_id plots with actuals.
    df_event_coords = None
    if groupvar == "pg_id":
        df_geo = prune_priogrid(fetch_df_geo_pgm(connectstring))
        size = get_figure_size(df_geo, scale=0.6)
        if have_actual:
            df_actuals = dbutils.db_to_df_limited(connectstring, schema_actual,
                                                  table_actual, [varname_actual],
                                                  timevar, groupvar,
                                                  time_start, time_end)
            df_event_coords = get_df_actuals_event_coords(df_actuals, df_geo,
                                                          varname_actual)
    else:  # groupvar == "country_id"
        df_geo = fetch_df_geo_c(crop)
        size = get_figure_size(df_geo, scale=0.6)

    df_months = fetch_df_months(connectstring) if timevar == 'month_id' else None

    plotvar_bounds = get_var_bounds(df_plotvar, plotvar)
    print("Bounds: ", plotvar_bounds)
    times = range(time_start, time_end + 1)

    cmap = get_cmap(variable_scale)
    ticks = make_ticks(variable_scale)

    dir_schema = dir_plots + schema_plotvar + "/"
    dir_table = dir_schema + table_plotvar + "/"
    dir_plotvar = dir_table + plotvar + "/"
    create_dirs([dir_plots, dir_schema, dir_table, dir_plotvar])

    # Annotation material is identical for every time step: build it once
    # instead of re-reading the logo files on each iteration.
    s_t_plotvar = "{}.{}".format(schema_plotvar, table_plotvar)
    text_box = "Modelname: {}\nRun: {}\nTable: {}".format(plotvar,
                                                             run_id,
                                                             s_t_plotvar)
    bbox = {'boxstyle': 'square', 'facecolor': "white"}
    cbar_fontsize = size[1] / 2
    textbox_xy, box_logo_eu, box_logo_erc, box_logo_views = _logo_layout(df_geo)
    logo_eu = mpimg.imread("/storage/static/logos/eu.png")
    logo_erc = mpimg.imread("/storage/static/logos/erc.png")
    logo_views = mpimg.imread("/storage/static/logos/views.png")

    for t in times:
        print("Plotting for {}".format(t))
        df_plotvar_t = df_plotvar.loc[t]

        _, ax = plt.subplots(figsize=size)

        print("Making basemap")
        bmap = get_basemap(projection, df_geo)

        print("Reading shape")
        if groupvar == "pg_id":
            bmap.readshapefile(path_shape_pg, 'GID', drawbounds=False)
        else:
            bmap.readshapefile(path_shape_c, 'ID', drawbounds=False)

        if groupvar == "pg_id" and have_actual:
            events_t = get_events_t(df_event_coords, t)
            bmap = plot_events_on_map(bmap, events_t)

        print("Making collection")
        # Plot the probs
        collection = make_collection(bmap,
                                     df_plotvar_t,
                                     cmap,
                                     ticks['values'],
                                     variable_scale,
                                     plotvar_bounds,
                                     groupvar)
        ax.add_collection(collection)

        if groupvar == "pg_id":
            # The Africa limited shapefile
            bmap.readshapefile(path_shape_pg, 'GID', drawbounds=True)

        if variable_scale in ["logodds", "prob"]:
            # if we're plotting logodds or probs set custom ticks
            cbar = plt.colorbar(collection,
                                ticks=ticks['values'],
                                fraction=0.046,
                                pad=0.04)
            cbar.ax.set_yticklabels(ticks['labels'], size=cbar_fontsize)
        else:
            # else use default colorbar for interval variables
            plt.colorbar(collection)

        bmap.drawmapboundary()

        if groupvar == "pg_id":
            # NOTE(review): pg bounds are drawn a second time here, as in the
            # original; possibly redundant with the call above.
            bmap.readshapefile(path_shape_pg, 'GID', drawbounds=True)
            # Country borders: a 2px white line, then a 1px black line.
            bmap.readshapefile(path_shape_c, 'ID', drawbounds=True, color='w',
                               linewidth=2)
            bmap.readshapefile(path_shape_c, 'ID', drawbounds=True, color='k',
                               linewidth=1)
        else:
            bmap.readshapefile(path_shape_c, 'ID', drawbounds=True)

        plt.text(textbox_xy[0], textbox_xy[1],
                 text_box, bbox=bbox, fontsize=size[1] * 0.5)

        plt.imshow(logo_erc, extent=box_logo_erc)
        plt.imshow(logo_eu, extent=box_logo_eu)
        plt.imshow(logo_views, extent=box_logo_views)

        text_title = "TITLE"
        if timevar == 'month_id':
            text_title = month_id_to_datestr(df_months, t)
        plt.figtext(0.5, 0.85, text_title, fontsize=size[1], ha='center')

        path = dir_plotvar + str(t) + ".png"
        plt.savefig(path, bbox_inches="tight")
        print("wrote", path)
        plt.close()
Esempio n. 11
0
table_ds, table_osa, table_ebma = (
    "ds_pgm_eval_test", "osa_pgm_eval_test", "ebma_pgm_eval_test")

table_actuals = "flight_pgm"

timevar, groupvar = "month_id", "pg_id"
ids = [timevar, groupvar]

outcomes = ["ged_dummy_" + kind for kind in ("sb", "ns", "os")] + [
    "acled_dummy_pr"]
# Last three characters of each outcome name, e.g. "_sb" for "ged_dummy_sb".
outcomes_suffix = list(map(lambda name: name[-3:], outcomes))

try:
    df_ds = dbutils.db_to_df(connectstring, schema_pred, table_ds)
    df_osa = dbutils.db_to_df(connectstring, schema_pred, table_osa)
    df_ebma = dbutils.db_to_df(connectstring, schema_pred, table_ebma)
    df_ds.set_index(ids, inplace=True)
    df_osa.set_index(ids, inplace=True)
    df_ebma.set_index(ids, inplace=True)

    t_start_ds = df_ds.index.get_level_values(timevar).min()
    t_start_osa = df_osa.index.get_level_values(timevar).min()
    t_start_ebma = df_ebma.index.get_level_values(timevar).min()
    t_end_ds = df_ds.index.get_level_values(timevar).max()
    t_end_osa = df_osa.index.get_level_values(timevar).max()
    t_end_ebma = df_ebma.index.get_level_values(timevar).max()

    start_same = t_start_ds == t_start_osa == t_start_ebma
    end_same = t_end_ds == t_end_osa == t_end_ebma
Esempio n. 12
0
]

if True not in all_plot_bools:
    print("You didn't specify any plot types, exiting")
    sys.exit(1)

# "postgresql" is the canonical SQLAlchemy dialect name; the "postgres" alias
# was removed in SQLAlchemy 1.4. Assumes make_connectstring builds a
# SQLAlchemy URL from `prefix` — it is called with "postgresql" elsewhere.
connectstring = dbutils.make_connectstring(db="views",
                                           hostname="VIEWSHOST",
                                           port="5432",
                                           prefix="postgresql",
                                           uname="VIEWSADMIN")

dir_descriptive = "/storage/runs/current/descriptive"
dir_table = "/".join([dir_descriptive, schema, table])

df = dbutils.db_to_df(connectstring, schema, table, ids=[timevar, groupvar])
df.sort_index(inplace=True)

if plot_wawa:
    utils.plot_world_average_with_actuals(df, connectstring, dir_table,
                                          timevar, groupvar)
if plot_spaghetti:
    utils.plot_spaghetties(df, connectstring, dir_table)
if plot_hist:
    utils.plot_histograms(df, dir_table)
if plot_lpg:
    utils.plot_lines_per_group(df, dir_table)
if plot_abt:
    utils.plot_stats_by_time(df, dir_table)
if plot_pgcm:
    utils.plot_pgcm(df, connectstring, dir_table)
Esempio n. 13
0
        "ds_pgm_fcast_calib",
        "osa_pgm_fcast_calib"
    ],
    "eval_test": [
        "ds_pgm_eval_test",
        "osa_pgm_eval_test",
        "ensemble_pgm_eval_test"
    ],
    "eval_calib": [
        "ds_pgm_eval_calib",
        "osa_pgm_eval_calib"
    ]
}

# Get the country_id for each pgm
df_country_keys = dbutils.db_to_df(connectstring, "staging_test", "cpgm")
df_country_keys.set_index(["month_id", "pg_id"], inplace=True)

for time in times_tables:
    df = df_country_keys.copy()
    for table in times_tables[time]:
        print("Fetching {}".format(table))
        df_scratch = dbutils.db_to_df(connectstring, "landed", table, 
                          ids = ["month_id", "pg_id"])
        print("Merging {}".format(table))
        df = df.merge(df_scratch, left_index=True, right_index=True)
    print("Computing mean {}".format(time))
    df.reset_index(inplace=True)
    df = df.drop(columns=['pg_id'])
    df_mean = df.groupby(["month_id", "country_id"]).mean()
Esempio n. 14
0
def plot_spaghetties(df, connectstring, directory):
    """Write one spaghetti plot per numeric prediction column.

    For each numeric column ``c`` of ``df``, every group (index level 1)
    whose mean of ``c`` exceeds the column's global median is drawn, highest
    mean first. Next to it, the group's actuals series
    (``"pgm_" + maputils.match_plotvar_actual(c)``) is drawn in the same
    colour and linestyle. Actuals come from ``landed.agg_cm_actuals`` and
    are outer-merged onto ``df``.

    Args:
        df: predictions; level 1 of the index holds country ids (looked up
            in the country-name table).
        connectstring: database connection string passed to dbutils.
        directory: output root; plots go to ``<directory>/spaghetti/<c>.png``.
    """
    import math

    directory = directory + "/spaghetti/"
    create_dirs([directory])

    linecycler = itertools.cycle(["-", "--", "-.", ":"])
    colorcycler = itertools.cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])

    df_actuals = dbutils.db_to_df(connectstring, "landed", "agg_cm_actuals")
    df_actuals.set_index(['month_id', 'country_id'], inplace=True)
    df_actuals = subset_times(df_actuals, *get_times_actuals(df))

    df_names = fetch_df_country_names(connectstring)

    # Choose the prediction columns before merging: afterwards the merged-in
    # actuals columns would also be iterated as if they were predictions.
    columns = get_numeric_cols(df)
    df = df.merge(df_actuals, left_index=True, right_index=True, how='outer')

    for c in columns:
        path = directory + str(c) + ".png"
        actual = "pgm_" + maputils.match_plotvar_actual(c)

        global_median = df[c].median()

        # Plot the spaghetties with the highest mean country first
        df_means = df.groupby(level=1)[c].mean()
        df_means.sort_values(inplace=True, ascending=False)

        upper_limit_y = df[c].max() * 1.2

        plt.figure(figsize=(24, 12))
        # plt.ylim must be *called*: the old `plt.ylim = (...)` replaced the
        # pyplot function itself for the rest of the process. NaN limits
        # (all-NaN column) are rejected by matplotlib, so skip those.
        if math.isfinite(upper_limit_y):
            plt.ylim(-0.001, upper_limit_y)

        for g in df_means.index:
            dg = df.xs(g, level=1)

            if dg[c].mean() > global_median:
                country_name = df_names.loc[g]['name']

                ls = next(linecycler)
                color = next(colorcycler)
                plt.plot(dg[c],
                         alpha=0.5,
                         linestyle=ls,
                         color=color,
                         label=country_name)

                # Distinct label so the two same-styled lines can be told
                # apart in the legend.
                plt.plot(dg[actual],
                         alpha=0.5,
                         linestyle=ls,
                         color=color,
                         label="{} (actual)".format(country_name))

        # Title by column; the old title used the leaked loop variable `g`
        # (the last country iterated) and failed when there were no groups.
        plt.title(str(c))
        plt.legend(prop={'size': 8})
        plt.savefig(path)
        print("wrote", path)
        plt.close()
Esempio n. 15
0
def get_data(connectstring, runtype="fcast", period="calib"):
    """Join pgm-level and cm-level osa/ds predictions into one pgm frame.

    Loads ``landed.{osa,ds}_pgm_<runtype>_<period>`` (ids month_id, pg_id)
    and ``landed.{osa,ds}_cm_<runtype>_<period>`` (ids month_id, country_id).
    Each (month_id, pg_id) row gets its country_id from
    ``staging_test.cpgm``, and the country-level predictions are merged onto
    every pg_id row of that country and month. All merges use the pandas
    default inner join, so rows without a match are dropped.

    Args:
        connectstring: database connection string passed to dbutils.
        runtype: run type used in the table names, e.g. "fcast" or "eval".
        period: period used in the table names, e.g. "calib" or "test".

    Returns:
        DataFrame indexed by (month_id, pg_id) with the pgm and cm
        prediction columns; the linking country_id column is dropped.
    """
    schema_link_ids = "staging_test"
    schema_predictions = "landed"

    timevar = "month_id"
    groupvar_c = "country_id"
    groupvar_pg = "pg_id"
    ids_c = [timevar, groupvar_c]
    ids_pg = [timevar, groupvar_pg]

    table_osa_pgm = "_".join(["osa", "pgm", runtype, period])
    table_ds_pgm = "_".join(["ds", "pgm", runtype, period])
    table_osa_cm = "_".join(["osa", "cm", runtype, period])
    table_ds_cm = "_".join(["ds", "cm", runtype, period])
    table_link_ids = "cpgm"

    df_osa_pgm = dbutils.db_to_df(connectstring,
                                  schema_predictions,
                                  table_osa_pgm,
                                  ids=ids_pg)
    df_ds_pgm = dbutils.db_to_df(connectstring,
                                 schema_predictions,
                                 table_ds_pgm,
                                 ids=ids_pg)
    df_osa_cm = dbutils.db_to_df(connectstring,
                                 schema_predictions,
                                 table_osa_cm,
                                 ids=ids_c)
    df_ds_cm = dbutils.db_to_df(connectstring,
                                schema_predictions,
                                table_ds_cm,
                                ids=ids_c)
    df_link_ids = dbutils.db_to_df(connectstring,
                                   schema_link_ids,
                                   table_link_ids,
                                   ids=ids_pg)

    for df in [df_osa_pgm, df_ds_pgm, df_osa_cm, df_ds_cm]:
        df.sort_index(inplace=True)

    df_pgm = df_osa_pgm.merge(df_ds_pgm, left_index=True, right_index=True)
    # Adds the country_id column for each (month_id, pg_id).
    df_pgm = df_pgm.merge(df_link_ids, left_index=True, right_index=True)

    df_cm = df_osa_cm.merge(df_ds_cm, left_index=True, right_index=True)

    # Flatten both so they can be joined on the (month_id, country_id) columns.
    df_cm.reset_index(inplace=True)
    df_pgm.reset_index(inplace=True)

    df_pgm = df_pgm.merge(df_cm, on=[timevar, groupvar_c])
    df_pgm.set_index(ids_pg, inplace=True)
    df_pgm.drop(columns=[groupvar_c], inplace=True)
    return df_pgm
Esempio n. 16
0
sys.path.append("..")

import views_utils.dbutils as dbutils

# "postgresql" is the canonical SQLAlchemy dialect name; the "postgres" alias
# was removed in SQLAlchemy 1.4. Assumes make_connectstring builds a
# SQLAlchemy URL from `prefix` — it is called with "postgresql" elsewhere.
connectstring = dbutils.make_connectstring(db="views",
                                           hostname="VIEWSHOST",
                                           port="5432",
                                           prefix="postgresql",
                                           uname="VIEWSADMIN")

cols_actual = [
    "ged_dummy_sb", "ged_dummy_ns", "ged_dummy_os", "acled_dummy_pr"
]

# Get the country_id for each pgm
df_country_keys = dbutils.db_to_df(connectstring, "staging_test", "cpgm")
df_country_keys.set_index(["month_id", "pg_id"], inplace=True)

df_actuals = dbutils.db_to_df(connectstring,
                              "preflight",
                              "flight_pgm",
                              ids=["month_id", "pg_id"],
                              columns=cols_actual)
# Plotting code looks actuals up as "pgm_" + <actual name>.
df_actuals = df_actuals.add_prefix("pgm_")

# merge returns a new frame, so no defensive copy of df_country_keys is needed.
df = df_country_keys.merge(df_actuals, left_index=True, right_index=True)
# Mean of each pgm actual over all cells of a country in a month.
df_mean = df.groupby(["month_id", "country_id"]).mean()

table_out = "agg_cm_actuals"
print("table_out: {}".format(table_out))
Esempio n. 17
0
def get_data(connectstring, level="cm", runtype="fcast"):
    """Load actuals plus calibration- and test-period predictions for a level.

    Predictions are read from ``landed.{osa,ds}_<level>_<runtype>_{calib,test}``
    and, for ``level == "pgm"`` only, ``landed.cl_pgm_<runtype>_{calib,test}``.
    All are read with ids (month_id, <groupvar>) and merged per period.
    Actuals (ged_dummy_{sb,ns,os}, acled_dummy_pr) are read from
    ``preflight.flight_<level>``, limited to the calibration period's first
    and last month_id.

    Args:
        connectstring: database connection string passed to dbutils.
        level: "cm" (grouped by country_id) or "pgm" (grouped by pg_id).
        runtype: run type used in the table names, e.g. "fcast" or "eval".

    Returns:
        tuple (df_actuals, df_pred_calib, df_pred_test), each sorted by index.

    Raises:
        ValueError: if ``level`` is not "cm" or "pgm" (previously this
            surfaced as an UnboundLocalError on ``groupvar``).
    """
    groupvars = {"cm": "country_id", "pgm": "pg_id"}
    if level not in groupvars:
        raise ValueError(
            "level must be 'cm' or 'pgm', got {!r}".format(level))

    schema_actuals = "preflight"
    schema_predictions = "landed"

    table_actuals = "flight_" + level
    cols_actual = ["ged_dummy_" + t for t in ["sb", "ns", "os"]]
    cols_actual.append("acled_dummy_pr")

    timevar = "month_id"
    groupvar = groupvars[level]
    ids = [timevar, groupvar]

    def _fetch(model, period):
        """Read landed.<model>_<level>_<runtype>_<period> with the level ids."""
        table = "_".join([model, level, runtype, period])
        return dbutils.db_to_df(connectstring, schema_predictions, table,
                                ids=ids)

    df_osa_calib = _fetch("osa", "calib")
    df_osa_test = _fetch("osa", "test")

    df_ds_calib = _fetch("ds", "calib")
    df_ds_test = _fetch("ds", "test")

    # only include cl for pgm level
    if level == "pgm":
        df_cl_calib = _fetch("cl", "calib")
        df_cl_test = _fetch("cl", "test")

    assert_equal_times(df_osa_calib, df_ds_calib, timevar)
    assert_equal_times(df_osa_test, df_ds_test, timevar)

    t_start_calib = df_osa_calib.index.get_level_values(timevar).min()
    t_end_calib = df_osa_calib.index.get_level_values(timevar).max()

    df_pred_calib = df_osa_calib.merge(df_ds_calib,
                                       left_index=True, right_index=True)
    df_pred_test = df_osa_test.merge(df_ds_test,
                                     left_index=True, right_index=True)

    # only include cl for pgm level
    if level == "pgm":
        df_pred_calib = df_pred_calib.merge(df_cl_calib,
                                            left_index=True, right_index=True)
        df_pred_test = df_pred_test.merge(df_cl_test,
                                          left_index=True, right_index=True)

    df_actuals = dbutils.db_to_df_limited(connectstring,
                                          schema_actuals, table_actuals,
                                          columns=cols_actual,
                                          timevar=timevar,
                                          groupvar=groupvar,
                                          tmin=t_start_calib,
                                          tmax=t_end_calib)

    for df in [df_actuals, df_pred_calib, df_pred_test]:
        df.sort_index(inplace=True)

    return df_actuals, df_pred_calib, df_pred_test
Esempio n. 18
0
 def fetch_data(self):
     """Read the table ``self.schema_h``.``self.table_h`` into ``self.df_h``.

     Uses ``dbutils.db_to_df`` with ``self.connectstring``; the frame is
     stored on the instance rather than returned.
     """
     self.df_h = dbutils.db_to_df(connectstring=self.connectstring,
                                  schema=self.schema_h,
                                  table=self.table_h)