Exemplo n.º 1
0
def _format_table_row(values) -> list:
    """Return *values* as table cells, grouping thousands with spaces.

    Every entry except the last three is formatted with ``f"{x:,}"`` and the
    commas replaced by spaces (e.g. ``1234567`` -> ``'1 234 567'``). The last
    three entries (the summary columns) are passed through untouched;
    presumably the caller supplies them pre-formatted -- confirm against the
    caller.
    """
    return [f"{x:,}".replace(',', ' ') for x in values[:-3]] + list(values[-3:])


def add_table(length, quality, views, pages, dates, lang, ax: Axes) -> Table:
    """Add a table under the plot to display values from all input files.

    The table has one column per entry in *dates* (headed by the formatted
    date) followed by "Change", "Change in %" and "Average / page" columns.
    Rows are Length, Views and Pages; for ``lang == 'sv'`` a Quality row is
    inserted after Length.

    :param length: Per-file values followed by three summary cells.
    :param quality: Same layout as *length*; only used when ``lang == 'sv'``.
    :param views: Same layout as *length*.
    :param pages: Same layout as *length*.
    :param dates: Dates used as column headers.
    :param lang: Language code; ``'sv'`` adds the Quality row.
    :param ax: Axes the table is attached to.
    :return: The created table.
    """
    # Number of data rows; row 0 of the table is the column-header row.
    table_height = 4 if lang == 'sv' else 3
    table_width = len(length)
    table_text = [_format_table_row(length),
                  _format_table_row(views),
                  _format_table_row(pages)]
    row_labels = ['Length', 'Views', 'Pages']
    if lang == 'sv':
        table_text.insert(1, _format_table_row(quality))
        row_labels.insert(1, 'Quality')

    # NOTE: '%-d' (day without zero padding) is a platform-specific strftime
    # extension and is not supported on Windows.
    data_table = table(ax,
                       cellText=table_text,
                       rowLabels=row_labels,
                       colLabels=[d.strftime('%-d %b %Y') for d in dates] +
                       ['    Change    ', 'Change in %', 'Average / page'],
                       bbox=[0, -0.7, 0.9, 0.5],
                       loc='bottom')
    # Replace the bottom-right cell (last data row, last column) with '-'.
    data_table[table_height, table_width - 1].get_text().set_text('-')

    # Thicker borders on the header row and on the row-label column (-1).
    for col in range(table_width):
        data_table[0, col].set_lw(3.0)
    for row in range(1, table_height + 1):
        data_table[row, -1].set_lw(3.0)
    data_table.auto_set_column_width(range(table_width))

    # Colour the third- and second-to-last columns ("Change" and "Change in %"
    # when len(length) == len(dates) + 3) green when the change text contains
    # '+', red when it contains '-'.
    change_col = table_width - 3
    for row in range(1, table_height + 1):
        change_text = data_table[row, change_col].get_text().get_text()
        if '+' in change_text:
            colour = 'lightgreen'
        elif '-' in change_text:
            colour = 'lightcoral'
        else:
            continue
        for col in (change_col, change_col + 1):
            data_table[row, col].set_facecolor(colour)
            data_table[row, col].set_fill(True)
    return data_table
Exemplo n.º 2
0
def plot_aux_data(soln, config, log, scale):
    """
    Plot auxiliary data such as energy distribution and receiver functions.

    Top row, left: normalized histograms of the random-sample energies and of
    each solution cluster's energies. Top row, right: per-event energy of each
    solution, sorted ascending, with tables of the (up to) three best and
    three worst event IDs in rank order. Bottom: one stacked receiver
    function panel per solution for the first configured layer present in
    ``soln.subsurface``.

    :param soln: Solution container
    :type soln: Customized scipy.optimize.OptimizeResult
    :param config: Solution configuration
    :type config: dict
    :param log: Logging instance
    :type log: logging.Logger
    :param scale: Overall image scaling factor
    :type scale: float
    :return: Matplotlib figure containing the plotted data
    :rtype: matplotlib.figure.Figure
    """
    f = plt.figure(constrained_layout=False,
                   figsize=(6.4 * scale, 6.4 * scale))
    f.suptitle(config["station_id"], y=0.96, fontsize=16)
    gs = f.add_gridspec(2,
                        1,
                        left=0.1,
                        right=0.9,
                        bottom=0.1,
                        top=0.87,
                        hspace=0.3,
                        wspace=0.3,
                        height_ratios=[1, 2])
    gs_top = gs[0].subgridspec(1, 2)
    ax0 = f.add_subplot(gs_top[0, 0])
    ax1 = f.add_subplot(gs_top[0, 1])

    hist_alpha = 0.5
    soln_alpha = 0.3
    axis_font_size = 6 * scale
    title_font_size = 6 * scale
    nbins = 100

    # Plot energy distribution of samples and solution clusters. Each
    # histogram is scaled to a peak of 1 so they can be overlaid.
    energy_hist, bins = np.histogram(soln.sample_funvals, bins=nbins)
    energy_hist = energy_hist.astype(float) / np.max(energy_hist)
    ax0.bar(bins[:-1],
            energy_hist,
            width=np.diff(bins),
            align='edge',
            color='#808080',
            alpha=hist_alpha)

    for i, cluster_energies in enumerate(soln.cluster_funvals):
        # Cluster histograms reuse the bin edges of the sample histogram.
        cluster_hist, _ = np.histogram(cluster_energies, bins)
        cluster_hist = cluster_hist.astype(float) / np.max(cluster_hist)
        ax0.bar(bins[:-1],
                cluster_hist,
                width=np.diff(bins),
                align='edge',
                color=f'C{i}',
                alpha=soln_alpha)
    # end for
    ax0.set_title(
        'Energy distribution of random samples and solution clusters',
        fontsize=title_font_size)
    ax0.set_xlabel('$E_{SU}$ energy (arb. units)')
    ax0.set_ylabel('Normalized counts')
    ax0.tick_params(labelsize=axis_font_size)
    ax0.xaxis.label.set_size(axis_font_size)
    ax0.yaxis.label.set_size(axis_font_size)

    # Plot sorted per-event upwards S-wave energy at top of mantle per solution.
    # Collect event IDs of worst fit traces and present as table of waveform IDs.
    event_ids = config["event_ids"]
    events_best3 = []
    events_worst3 = []
    for i, esu in enumerate(soln.esu):
        assert len(esu) == len(event_ids)
        esu_sorted = sorted(zip(esu, event_ids))
        events_best3.extend(esu_sorted[:3])
        events_worst3.extend(esu_sorted[-3:])
        ax1.plot([e[0] for e in esu_sorted], color=f'C{i}', alpha=soln_alpha)
    # end for
    # Keep the first three distinct event IDs in rank order (lowest energy
    # first for BEST, highest first for WORST). dict.fromkeys de-duplicates
    # while preserving order; a set would list the events in arbitrary,
    # hash-dependent order.
    best_events = list(dict.fromkeys(
        evid for _, evid in sorted(events_best3)))[:3]
    worst_events = list(dict.fromkeys(
        evid for _, evid in sorted(events_worst3, reverse=True)))[:3]
    _tab1 = table(ax1,
                  cellText=[[e] for e in best_events],
                  colLabels=['BEST'],
                  cellLoc='left',
                  colWidths=[0.35],
                  loc='upper left',
                  edges='horizontal',
                  fontsize=8,
                  alpha=0.6)  # alpha broken in matplotlib.table!
    _tab2 = table(ax1,
                  cellText=[[e] for e in worst_events],
                  colLabels=['WORST'],
                  cellLoc='left',
                  colWidths=[0.35],
                  loc='upper right',
                  edges='horizontal',
                  fontsize=8,
                  alpha=0.6)
    ax1.set_title('Ranked per-event energy for each solution point',
                  fontsize=title_font_size)
    ax1.set_xlabel('Rank (out of # source events)')
    ax1.set_ylabel('Event $E_{SU}$ energy (arb. units)')
    ax1.tick_params(labelsize=axis_font_size)
    ax1.xaxis.label.set_size(axis_font_size)
    ax1.yaxis.label.set_size(axis_font_size)

    # Plot receiver function at base of selected layers
    max_solutions = config["solver"].get("max_solutions", 3)
    for layer in config["layers"]:
        lname = layer["name"]
        if not (soln.subsurface and lname in soln.subsurface):
            continue
        base_seismogms = soln.subsurface[lname]
        # Generate RF and plot: one row of the bottom grid per solution.
        gs_bot = gs[1].subgridspec(max_solutions, 1, hspace=0.4)
        for i, seismogm in enumerate(base_seismogms):
            if i >= max_solutions:
                # The grid only has max_solutions rows; gs_bot[i] would
                # raise IndexError beyond that.
                log.warning('More solutions than max_solutions=%d; '
                            'only the first %d RFs are plotted',
                            max_solutions, max_solutions)
                break
            soln_rf = _compute_rf(seismogm, config, log)
            assert isinstance(soln_rf, rf.RFStream)
            # Remove any traces for which deconvolution failed.
            # First, find their unique ID. Then remove all traces with that ID.
            exclude_ids = {
                tr.stats.event_id for tr in soln_rf if len(tr) == 0
            }
            soln_rf = rf.RFStream([
                tr for tr in soln_rf
                if tr.stats.event_id not in exclude_ids
            ])
            axn = f.add_subplot(gs_bot[i])
            if soln_rf:
                rf_R = soln_rf.select(component='R').trim2(
                    RF_TRIM_WINDOW[0], RF_TRIM_WINDOW[1], reftime='onset')
                num_RFs = len(rf_R)
                times = rf_R[0].times() + RF_TRIM_WINDOW[0]
                data = rf_R.stack()[0].data
                axn.plot(times,
                         data,
                         color=f'C{i}',
                         alpha=soln_alpha,
                         linewidth=2)
                axn.text(0.95,
                         0.95,
                         'N = {}'.format(num_RFs),
                         fontsize=10,
                         ha='right',
                         va='top',
                         transform=axn.transAxes)
                axn.set_xlabel('Time (sec)')
                axn.grid(color='#80808080', linestyle=':')
            else:
                axn.annotate('Empty RF plot', (0.5, 0.5),
                             xycoords='axes fraction',
                             ha='center')
            # end if
            axn.set_title(' '.join([
                config["station_id"], lname, 'base RF',
                '(soln {})'.format(i)
            ]),
                          fontsize=title_font_size,
                          y=0.92,
                          va='top')
            axn.tick_params(labelsize=axis_font_size)
            axn.xaxis.label.set_size(axis_font_size)
            axn.yaxis.label.set_size(axis_font_size)
        # end for
        break  # TODO: Figure out how to add more layers if needed
    # end for

    return f
Exemplo n.º 3
0
"""
Ref:
    https://stackoverflow.com/questions/35634238/how-to-save-a-pandas-dataframe-table-as-a-png/35715029
    https://stackoverflow.com/questions/26678467/export-a-pandas-dataframe-as-a-table-image/26681726
    https://stackoverflow.com/questions/35634238/how-to-save-a-pandas-dataframe-table-as-a-png

"""
import matplotlib

matplotlib.use('Agg')
from matplotlib.pyplot import figure
from matplotlib.table import table
from pylab import *

fig = figure()

colLabels = ('Freeze', 'Wind', 'Flood', 'Quake', 'Hail')
rowLabels = ['%d year' % x for x in (100, 50, 20, 10, 5)]
cellText = [['66.4', '174.3', '75.1', '577.9', '32.0'], ['124.6', '555.4', '153.2', '677.2', '192.5'], ['213.8', '636.0', '305.7', '1175.2', '796.0'], ['292.2', '717.8', '456.4', '1368.5', '865.6'], ['431.5', '1049.4', '799.6', '2149.8', '917.9']]

# The axes is only a container for the table: hide its frame and both axes.
ax = fig.add_subplot(111, frame_on=False)
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)

# Attach the table to ``ax`` explicitly. A bare ``table(...)`` call only
# worked because ``from pylab import *`` shadowed ``matplotlib.table.table``
# (which requires ``ax``) with ``pyplot.table``. ``loc='center'`` keeps the
# table inside the axes; the default ``'bottom'`` places it below the axes,
# where it can extend past the figure edge and be clipped in the saved PNG.
# rowLabels was defined for this table but previously never passed.
ax.table(cellText=cellText,
         rowLabels=rowLabels,
         colLabels=colLabels,
         loc='center')

fig.savefig('test12.png')
Exemplo n.º 4
0
    def plot_tables_planning_time_curve(self, res, res_1, simple, horizon, x,
                                        y, hue, execs):
        """Plot *y* against *x* (one line per *hue*) and save it as a PNG.

        Draws a seaborn line plot of ``res``. When *y* is not
        ``"final_total_belief_reward"`` and *res_1* is non-empty, a table of
        the *y* value for each (x value, hue value) pair in *res_1* is
        overlaid on the axes. The figure is saved under
        ``self.POMDPTasks.test_folder + '/results_<execs>/'`` and then closed.

        :param res: DataFrame with ``'id'``, *x*, *y* and *hue* columns.
        :param res_1: DataFrame with *hue*, *x* and *y* columns for the
            overlay table; ``None`` or empty disables the table.
        :param simple: Value embedded in the file name; also selects the
            table placement.
        :param horizon: Planning horizon, shown in the title.
        :param x: Column plotted on the x axis.
        :param y: Column plotted on the y axis.
        :param hue: Column distinguishing the plotted lines.
        :param execs: Execution count embedded in the output path.
        """
        print("----- ", y)
        data = res[['id', x, y, hue]]
        n_x_values = len(data[x].unique())
        if self.compare_3T:
            fig = plt.figure(figsize=(max(n_x_values, 8), 8))
        else:
            width_reduction = 3 if n_x_values < 8 else 4
            fig = plt.figure(figsize=(max(n_x_values, 8) - width_reduction,
                                      7))
        ax = plt.gca()

        # Only the smallest horizon (or the 3T comparison) gets a legend.
        if horizon == min(self.horizons) or self.compare_3T:
            ax = sns.lineplot(x=x,
                              y=y,
                              hue=hue,
                              style="algorithm",
                              markers=True,
                              data=data)
            plt.setp(ax.get_legend().get_texts(),
                     fontsize='22')  # for legend text
            plt.setp(ax.get_legend().get_title(),
                     fontsize='22')  # for legend title
        else:
            ax = sns.lineplot(x=x,
                              y=y,
                              hue=hue,
                              style="algorithm",
                              markers=True,
                              data=data,
                              legend=False)

        if (y != "final_total_belief_reward" and res_1 is not None
                and not res_1.empty):
            data_1 = res_1[[hue, x, y]]
            tab_names = data_1[x].unique()
            alg_names = data_1[hue].unique()
            # Row 0: first character of each algorithm name.
            # Column 0: "tables = <x value>" labels.
            table_arr = np.full((len(tab_names) + 1, len(alg_names) + 1),
                                "",
                                dtype=object)
            for a, alg_name in enumerate(alg_names):
                table_arr[0, a + 1] = alg_name[0]
            for t, tab_name in enumerate(tab_names):
                table_arr[t + 1, 0] = "tables = " + str(tab_name)
                for a, alg_name in enumerate(alg_names):
                    sample = data_1.loc[(data_1[hue] == alg_name)
                                        & (data_1[x] == tab_name)]
                    if not sample.empty:
                        # .item() requires exactly one matching row (as
                        # float(Series) did) without the pandas deprecation.
                        value = float(sample[y].item())
                        table_arr[t + 1, a + 1] = str(np.round(value, 1)) + "s"

            # bbox is [left, bottom, width, height] in axes coordinates.
            if self.compare_3T and simple != 1:
                bbox = [0.1, .4, .9, .2]
            elif horizon == 2 or (simple == 1 and horizon == 5):
                bbox = [0.1, .2, .7, .3]
            else:
                bbox = [0.0, .7, 1.0, .3]
            tab = table(ax,
                        cellText=table_arr,
                        cellLoc='center',
                        rowLoc='center',
                        loc='upper left',
                        bbox=bbox)
            tab.auto_set_font_size(False)
            tab.set_fontsize(24)

        ax.set_title("horizon: " + str(horizon), fontsize=26)
        matplt.xticks(
            np.arange(self.tables[0], max(max(data[x].unique()) + 1, 9)))

        # Fall back to the raw column name when y contains neither keyword,
        # so ylab is always bound.
        ylab = y
        if "time" in y:
            ylab = "planning time (s)"
        if "reward" in y:
            ylab = "reward"
        ax.set_xlabel(x, fontsize=26)
        ax.set_ylabel(ylab, fontsize=26)
        ax.tick_params(labelsize=22)

        # Colour the n-th x tick label orange, n being the number of distinct
        # x values that have "B:Agent POMDP" rows. Skipped when there are
        # none (index -1 would wrongly colour the last tick).
        agent_pomdp_max_table = data.loc[data[hue] == "B:Agent POMDP"]
        n_agent_pomdp = len(agent_pomdp_max_table[x].unique())
        if n_agent_pomdp > 0:
            ax.get_xticklabels()[n_agent_pomdp - 1].set_color("orange")

        plt.tight_layout()
        out_path = (self.POMDPTasks.test_folder + '/results_' + str(execs) +
                    '/' + "3T_" + str(self.compare_3T) + "_" + x + "_" + hue +
                    "_" + y + '_simple-' + str(simple) + '_horizon-' +
                    str(horizon) + "_execs-" + str(execs) + '.png')
        matplt.savefig(out_path)
        # Release the figure; otherwise every call leaks one open figure.
        plt.close(fig)
    r'\# model runs'
]
rowLoc = 'right'

# cell_text = list(cols.flatten().astype(str))
cols = np.append(cols, cols.sum(axis=1).reshape((-1, 1)), axis=1)
cell_text = cols.astype(str)
cell_text[0, -1] = 'Total'
cell_text = cell_text[[0, 2, 3], :]
est1_ms.model_runs = int(cell_text[-1, -1])

table = tbl.table(ax[1],
                  cellText=cell_text,
                  cellLoc=cellLoc,
                  cellColours=colors,
                  colWidths=colWidths,
                  rowLabels=rowLabels,
                  rowLoc=rowLoc,
                  fontsize=(config.TICK_FONTSIZE - 2) * config.SCALE,
                  loc='center')
table.scale(1, 3)

for cell in table._cells:
    table._cells[cell].set_alpha(0.1)

ax[1].add_table(table)
plt.subplots_adjust(hspace=0.4)

# print(cols)

ax[0] = fp.add_title(ax[0],
Exemplo n.º 6
0
def plt_figure(directory, trace, peaks_vals, spacer, pos_in_call, filename,
               pos_list, base_pos_in_spacer, strand):
    """Plot the trace and a table showing % of each base along the spacer.

    The top panel shows the G/A/T/C trace channels over the spacer window
    (padded by 8 samples on each side). The bottom panel is a table with rows
    G, A, T, C and one column per spacer base, coloured by value. The figure
    is saved to ``directory + '/result/' + <filename stem> + '.png'``.

    :param directory: Output base directory.
    :param trace: Mapping from base letter ('G', 'A', 'T', 'C') to a
        sliceable channel signal.
    :param peaks_vals: Peak sample positions, indexed by call position.
    :param spacer: Spacer sequence; its characters label the table columns.
    :param pos_in_call: Index in *peaks_vals* of the first spacer base.
    :param filename: Input file name; the text before the first '.' is used
        as title and output file stem.
    :param pos_list: Table rows for G, A, T, C (in that order); presumably
        percentages in 0-100 -- confirm against the caller.
    :param base_pos_in_spacer: 1-based spacer column whose header is red.
    :param strand: ``'+'`` labels the table "5'>", anything else "3'<".
    """
    # Use the axes created here directly. Re-creating them with
    # plt.subplot2grid adds new overlapping axes in recent Matplotlib
    # (overlapping axes are no longer auto-removed), leaving the original,
    # framed axes visible underneath.
    fig, (ax1, ax2) = plt.subplots(2, figsize=(4, 2.5))
    fig.patch.set_visible(False)
    arial_font = {'fontname': 'Arial'}

    # Trace window: 8 samples before the first spacer peak up to 8 samples
    # after the last spacer peak.
    window = slice(peaks_vals[pos_in_call] - 8,
                   peaks_vals[pos_in_call + len(spacer) - 1] + 8)
    # (base, colour, x position in axes points of its legend letter)
    channels = (('G', 'black', 10), ('A', 'g', 14), ('T', 'r', 18),
                ('C', 'b', 22))
    for base, colour, _ in channels:
        ax1.plot(trace[base][window], color=colour, label=base, linewidth=0.6)
    # Hand-drawn legend: one coloured letter per channel.
    for base, colour, x_pt in channels:
        ax1.annotate(base,
                     xy=(x_pt, 60),
                     xycoords='axes points',
                     size=5,
                     color=colour,
                     weight='bold',
                     ha='right',
                     va='top')
    ax1.axis('off')
    ax1.axis('tight')
    ax1.set_title(filename.split(".")[0], fontsize=8, **arial_font)

    ax2.axis('off')
    ax2.axis('tight')
    df = pd.DataFrame(pos_list, columns=list(spacer))
    # Scale values to integer colormap indices (100 -> 256).
    colour_idx = (df.values * 2.56).astype(int)
    to_rgb = mcolors.ColorConverter().to_rgb
    mycm = make_colormap([
        to_rgb('white'),
        to_rgb('yellow'), 0.3,
        to_rgb('yellow'),
        to_rgb('red'), 0.60,
        to_rgb('red'),
        to_rgb('green'), 0.8,
        to_rgb('green')
    ])
    # plt.get_cmap accepts a Colormap instance; plt.cm.get_cmap was removed
    # in Matplotlib 3.9.
    colours = plt.get_cmap(mycm)(colour_idx)
    stats_table = table(ax2,
                        cellText=df.values,
                        rowLabels='GATC',
                        colLabels=df.columns,
                        loc='center',
                        cellColours=colours)
    # Header of the highlighted spacer position in red; row labels A, T, C
    # coloured like their trace lines.
    stats_table[0, base_pos_in_spacer - 1].get_text().set_color('red')
    stats_table[2, -1].get_text().set_color('green')
    stats_table[3, -1].get_text().set_color('red')
    stats_table[4, -1].get_text().set_color('blue')

    # No cell borders; bold row labels, larger header text than body text.
    for (row, col), cell in stats_table.get_celld().items():
        cell.set_linewidth(0)
        if col == -1:
            cell.set_text_props(fontproperties=FontProperties(
                family='Arial', size=8, weight='bold'))
        elif row == 0:
            cell.set_text_props(
                fontproperties=FontProperties(family='Arial', size=8))
        else:
            cell.set_text_props(
                fontproperties=FontProperties(family='Arial', size=6))

    # Shrink and shift the table axes.
    box = ax2.get_position()
    ax2.set_position(
        [box.x0 + 0.038, box.y0 + 0.15, box.width * 0.9, box.height * 0.8])
    strand_label = "5'>" if strand == "+" else "3'<"
    ax2.annotate(strand_label,
                 xy=(0, 47),
                 xycoords='axes points',
                 size=5,
                 color='k',
                 weight='bold',
                 ha='right',
                 va='top')

    figfile = filename.split('.')[0] + '.png'
    fig.savefig(directory + '/result/' + figfile, dpi=300)
    plt.close(fig)
Exemplo n.º 7
0
                    ],
                                  [
                                      'Maximum Head above\nLand Surface (ft)',
                                      str(round(abs(max_dtw), 2))
                                  ],
                                  [
                                      'Number of Locations\nwith Max Head',
                                      str(num_max_values)
                                  ]]

                    cell_color = [['white', 'white'], ['white', 'white'],
                                  ['white', 'white']]

                    tbl = table(ax,
                                cellText=table_vals,
                                cellColours=cell_color,
                                bbox=[0.5, 0.7, 0.40, 0.2],
                                rasterized=True,
                                zorder=10)

                    tbl.auto_set_font_size(False)
                    tbl.set_fontsize(8)

                    tbl.auto_set_column_width(col=[0, 1])

                    # plot locations of max heads above ground surface
                    ax.scatter(max_values.X,
                               max_values.Y,
                               s=50,
                               facecolors='none',
                               edgecolors='r')
Exemplo n.º 8
0
def plot_catalogue_performance(data_table, strategies, filename=None,
                               figsize=(12, 6),
                               zbins=((0.9, 2.1), (2.1, None)),
                               desi_nqso=(1.3 * 10**6, 0.8 * 10**6),
                               dv_max=6000., show_correctwrongzbin=False,
                               verbose=False, nydec=0, ymax=0.1, filter=None,
                               add_bar_heights=True, extrarow=False,
                               rotation=0.):
    """Plot QSO-catalogue contamination per strategy and redshift bin.

    For each bin in *zbins* a stacked bar chart shows, per strategy, the
    fraction of catalogue entries that are stars, galaxies with wrong z and
    QSOs with wrong z (optionally also correct-z objects in the wrong z-bin).
    A table under each panel gives completeness and an estimated DESI
    catalogue size.

    Side effect: each ``strategies[s]`` gains the keys ``'ncat'``,
    ``'nstar'``, ``'ngalwrongz'``, ``'nqsowrongz'``, ``'ncorrectwrongzbin'``,
    ``'nwrong'`` and ``'completeness'`` (values from the last z-bin).

    :param data_table: Table with a ``'Z_VI'`` column; passed to ``get_truths``.
    :param strategies: Mapping of strategy name to a dict holding at least
        ``'z'`` and ``'isqso'`` arrays and optionally a ``'label'``.
    :param filename: If not ``None``, the figure is saved to this path.
    :param figsize: Figure size.
    :param zbins: Sequence of ``(zmin, zmax)`` pairs; ``None`` = unbounded.
    :param desi_nqso: Expected number of DESI QSOs per z-bin (same length as
        *zbins*).
    :param dv_max: Maximum ``dv`` (from ``strategy.get_dv``) for a redshift
        to count as correct.
    :param show_correctwrongzbin: Also stack the correct-z/wrong-z-bin fraction.
    :param verbose: Print counts for strategies whose name contains 'RR'.
    :param nydec: Decimals shown on the percentage y axis.
    :param ymax: Upper y-axis limit.
    :param filter: Optional boolean mask over the rows of *data_table*.
    :param add_bar_heights: Annotate the bars with their total height.
    :param extrarow: Prepend an empty table row; forced on when any strategy
        label contains a newline.
    :param rotation: Rotation of the x tick labels in degrees.
    :return: ``(fig, axs)`` where ``axs`` has shape ``(1, len(zbins))``.
    """
    fig, axs = plt.subplots(1, len(zbins), figsize=figsize, sharey=True,
                            squeeze=False)

    if filter is None:
        filt = np.ones(len(data_table), dtype=bool)
    else:
        filt = filter

    # determine the true classifications
    isqso_truth, isgal_truth, isstar_truth, isbad = get_truths(data_table)

    def _per_strategy(key):
        """Collect ``strategies[s][key]`` for every strategy, in dict order."""
        return np.array([strategies[s][key] for s in strategies])

    for i, zbin in enumerate(zbins):

        for s in strategies:
            st = strategies[s]

            # Make a filter to deal with masked arrays.
            filt_s = filt & (st['isqso'] | True)

            z_s = st['z']
            w_s = st['isqso']

            # z-bin membership by Z_VI (presumably the visually inspected
            # redshift) and by the strategy's own redshift.
            in_zbin_zvi = np.ones(data_table['Z_VI'].shape, dtype=bool)
            in_zbin_zs = np.ones(z_s.shape, dtype=bool)
            if zbin[0] is not None:
                in_zbin_zvi &= (data_table['Z_VI'] >= zbin[0])
                in_zbin_zs &= (z_s >= zbin[0])
            if zbin[1] is not None:
                in_zbin_zvi &= (data_table['Z_VI'] < zbin[1])
                in_zbin_zs &= (z_s < zbin[1])

            dv = strategy.get_dv(z_s, data_table['Z_VI'], data_table['Z_VI'],
                                 use_abs=True)
            zgood = (dv <= dv_max)

            st['ncat'] = (w_s & in_zbin_zs & (~isbad) & filt_s).sum()

            st['nstar'] = (w_s & in_zbin_zs & isstar_truth & filt_s).sum()
            st['ngalwrongz'] = (w_s & ~zgood & in_zbin_zs & isgal_truth
                                & filt_s).sum()
            st['nqsowrongz'] = (w_s & ~zgood & in_zbin_zs & isqso_truth
                                & filt_s).sum()
            st['ncorrectwrongzbin'] = (w_s & zgood & in_zbin_zs
                                       & (isqso_truth | isgal_truth)
                                       & ~in_zbin_zvi & filt_s).sum()

            st['nwrong'] = (st['nstar'] + st['ngalwrongz']
                            + st['nqsowrongz'] + st['ncorrectwrongzbin'])

            com_num = (w_s & zgood & in_zbin_zs & isqso_truth & in_zbin_zvi
                       & filt_s).sum()
            com_denom = (isqso_truth & in_zbin_zvi & filt_s).sum()
            st['completeness'] = com_num / com_denom

            if ('RR' in s) and verbose:
                print(s)
                for k in st:
                    if (k[0] == 'n') or (k == 'completeness'):
                        print(k, st[k])
                nlostqso = (~w_s & isqso_truth & in_zbin_zvi).sum()
                print(nlostqso, com_denom, nlostqso / com_denom)
                print('')

        nwrong = _per_strategy('nwrong')
        ncat = _per_strategy('ncat')

        pstar = _per_strategy('nstar') / ncat
        pgalwrongz = _per_strategy('ngalwrongz') / ncat
        pqsowrongz = _per_strategy('nqsowrongz') / ncat
        pcorrectwrongzbin = _per_strategy('ncorrectwrongzbin') / ncat

        completeness = _per_strategy('completeness')

        xpos = range(len(strategies))
        axs[0, i].bar(xpos, pstar, color=utils.colours['C0'], label='star',
                      width=0.5)
        axs[0, i].bar(xpos, pgalwrongz, bottom=pstar,
                      color=utils.colours['C1'],
                      label='galaxy w.\nwrong $z$', width=0.5)
        bars = axs[0, i].bar(xpos, pqsowrongz, bottom=pstar + pgalwrongz,
                             color=utils.colours['C2'],
                             label='QSO w.\nwrong $z$', width=0.5)
        if show_correctwrongzbin:
            axs[0, i].bar(xpos, pcorrectwrongzbin,
                          bottom=pstar + pgalwrongz + pqsowrongz,
                          color=utils.colours['C3'],
                          label='correct w.\nwrong $z$-bin', width=0.5)

        if add_bar_heights:
            bar_heights = pstar + pgalwrongz + pqsowrongz
            if show_correctwrongzbin:
                bar_heights += pcorrectwrongzbin
            utils.autolabel_bars(axs[0, i], bars, numbers=bar_heights,
                                 heights=bar_heights, percentage=True,
                                 above=True)

        # Estimated DESI catalogue size in millions:
        # completeness * expected QSOs / (1 - contamination fraction).
        DESI_ncat_presents = []
        for nwrong_j, ncat_j, c in zip(nwrong, ncat, completeness):
            pcon = nwrong_j / ncat_j
            DESI_ncat = c * desi_nqso[i] / (1 - pcon)
            DESI_ncat_presents.append(round(DESI_ncat * 10**-6, 3))

        axs[0, i].set_xlabel('classification strategy', labelpad=10)
        axs[0, i].set_xticks(xpos)
        axs[0, i].set_xlim(-0.5, len(strategies) - 0.5)
        slabels = [strategies[s].get('label', s) for s in strategies]
        ha = 'right' if rotation > 0 else 'center'
        axs[0, i].set_xticklabels(slabels, rotation=rotation, ha=ha,
                                  rotation_mode="anchor")

        axs[0, i].yaxis.set_major_formatter(
            mtick.PercentFormatter(xmax=1.0, decimals=nydec))
        axs[0, i].set_ylim(0, ymax)
        zbin_label = get_label_from_zbin(zbin)
        axs[0, i].text(0.5, 1.05, zbin_label, ha='center', va='center',
                       transform=axs[0, i].transAxes)

        # Multi-line tick labels need an empty spacer row in the table.
        if any('\n' in s for s in slabels):
            extrarow = True
        cell_text = []
        if extrarow:
            cell_text.append([''] * len(completeness))
        cell_text.append(['{:2.1%}'.format(c) for c in completeness])
        cell_text.append(['{:1.2f}'.format(c) for c in DESI_ncat_presents])

        rowLabels = [''] if extrarow else []
        if i == 0:
            rowLabels += ['completeness:',
                          'estimated DESI\ncatalogue size\n[million QSOs]:']
        else:
            rowLabels += ['', '']
        tab = mp_table.table(cellText=cell_text,
                             rowLabels=rowLabels,
                             colLabels=['' for _ in strategies],
                             loc='bottom',
                             ax=axs[0, i],
                             edges='open',
                             cellLoc='center',
                             rowLoc='right',
                             in_layout=True)
        tab.scale(1, 6)

    axs[0, 0].set_ylabel('contamination of\nQSO catalogue')
    # Legend on the second panel, or on the only panel for a single z-bin
    # (axs[0, 1] would raise IndexError there).
    axs[0, min(1, len(zbins) - 1)].legend()

    rect = (0.07, 0.2, 1.0, 1.0)
    plt.tight_layout(rect=rect)
    if filename is not None:
        plt.savefig(filename)

    return fig, axs
Exemplo n.º 9
0
 def __call__(self, data):
     """Consume one ``(t, y)`` sample, update the model and redraw.

     Feeds *y* to ``self.model.update``, appends the newly buffered samples
     to the "active" line, updates the estimate and match lines, scrolls
     the axis limits, redraws the table of live appliances and, every
     ``MODEL_SAVE_STEP`` calls, saves the model with ``dill``.

     :param data: ``(t, y)`` tuple; *t* is the sample timestamp (uses its
         ``day``, ``hour``, ``minute`` and ``second``) and *y* is passed to
         ``self.model.update``.
     :return: ``(line_active, line_est, line_match)``, or ``None`` when
         ``self.model.update`` raises ``ValueError``.
     """
     seconds_per_hour = 3600
     seconds_per_day = 24 * seconds_per_hour
     t, y = data
     if self.model._last_visualized_ts is None or self.vis is None:
         self.vis = t
     elif self.vis < self.model._last_visualized_ts:
         # This is to avoid cases with repeated past measurements. Should
         # perhaps be handled in livehart.py
         self.vis = self.model._last_visualized_ts
     if self.start_day is None:
         self.start_day = t.day
     try:
         self.model.update(y)
     except ValueError as err:
         logging.error('ERROR: %s', err)
         return None
     # Update lines: append buffered samples newer than self.vis, with x
     # measured in seconds since midnight of self.start_day.
     bufidx = self.model._buffer.index
     tmpdata = self.model._buffer.loc[bufidx > self.vis]
     # NOTE(review): a day-of-month difference goes negative across a month
     # boundary; consider tracking a full start date instead.
     day = tmpdata.index.day - self.start_day
     self.xdata.extend(
         (day * seconds_per_day + tmpdata.index.hour * seconds_per_hour +
          tmpdata.index.minute * 60 + tmpdata.index.second).values.tolist())
     self.ydata.extend(tmpdata['active'].values.tolist())
     lim = min(len(self.xdata), self.time_window)
     self.line_active.set_data(self.xdata[-lim:], self.ydata[-lim:])
     # Seconds component of the gap between the newest buffered sample and
     # the last visualized one (timedelta.seconds ignores whole days); used
     # below as a sample count, presumably assuming 1 sample per second.
     lim1 = (bufidx[-1] - self.model._last_visualized_ts).seconds
     ydisp = self.model._yest[-(lim - lim1):].tolist()
     if lim1 > 0:
         self.line_est.set_data(self.xdata[-lim:-lim1], ydisp)
         ymatchdisp = self.model._ymatch['active'].values[
             -lim:-lim1].tolist()
         self.line_match.set_data(self.xdata[-lim:-lim1], ymatchdisp)
     else:
         self.line_est.set_data(self.xdata[-lim:], ydisp)
         ymatchdisp = self.model._ymatch['active'].values[-lim:].tolist()
         self.line_match.set_data(self.xdata[-lim:], ymatchdisp)
     # Update axis limits. tidx uses the same time scale as self.xdata
     # (including the day offset), so the view keeps following the data.
     tidx = ((t.day - self.start_day) * seconds_per_day +
             t.hour * seconds_per_hour + t.minute * 60 + t.second)
     xmin = max(0, tidx + self.step - self.time_window + 1)
     xmax = max(self.time_window, tidx + self.step)
     ymin = min(self.ydata[-self.time_window:] + [0])  # List concatenation
     ymax = max(self.ydata[-self.time_window:] + [0])  # List concatenation
     self.ax.set_xlim(xmin - 100, xmax + 100)
     self.ax.set_ylim(ymin - 50, ymax + 100)
     self.ax.figure.canvas.draw()
     # Add table: one row per live appliance (or a 'None' row), followed by
     # the residual ('Other') and background rows.
     if not self.model.live:
         cell_text = [['None', '-', '-']]
     else:
         cell_text = [[m.name, m.signature[0, 0], m.signature[0, 1]]
                      for m in self.model.live]
     cell_text.append(['Other', self.model.residual_live[0], '-'])
     cell_text.append(['Background', self.model.background_active, '-'])
     # Clear first: table() attaches the new table to self.axt, so the
     # previous frame's table must be removed before creating it.
     self.axt.clear()
     tab = table(self.axt,
                 cell_text,
                 colLabels=['Appliance', 'Active', 'Reactive'],
                 cellLoc='left',
                 colLoc='left',
                 edges='horizontal')
     for (row, col), cell in tab.get_celld().items():
         if (row == 0) or (col == -1):
             cell.set_text_props(fontproperties=FontProperties(
                 weight='bold'))
     self.axt.set_axis_off()
     self.axt.figure.canvas.draw()
     # Save the model on every MODEL_SAVE_STEP-th call.
     if self.model_path_w and \
        (self.save_counter % self.MODEL_SAVE_STEP == 0):
         with open(self.model_path_w, "wb") as fp:
             dill.dump(self.model, fp)
     self.save_counter += 1
     return self.line_active, \
         self.line_est, \
         self.line_match