Exemplo n.º 1
0
def cum_subrecruits(ax, recruits, first_plot_day, color, linestyle):
    """Draw cumulative recruit curves, one participant label per axis.

    Every label returned by plot_utils.get_labels_to_plot is drawn on its own
    axis (taken from ax in order). The axis is titled with the label and shows
    the cumulative sum of that label's recruits along the time dimension.

    Args:
        ax: An indexable collection of axis instances to plot on.
        recruits: An xr.DataArray representing the expected or observed
            recruits. Must have a time dimension.
        first_plot_day: An int representing the first date to plot.
        color: A mpl color.
        linestyle: A mpl linestyle.
    """
    unpacked = plot_utils.unpack_participant_labels(recruits)
    labels = plot_utils.get_labels_to_plot(recruits)
    n_axes = len(ax)
    # The last two axes get a date format; every other axis is formatted
    # without labels.
    labeled_axes = (n_axes - 2, n_axes - 1)

    for idx, label in enumerate(labels):
        axis = ax[idx]
        per_label = unpacked.sel(participant_label=label, drop=True)
        axis.set_title(label)
        time_dim = plot_utils.find_time_dim(per_label)
        cumulative = per_label.cumsum(time_dim)
        line_kwargs = {'color': color, 'ls': linestyle}
        array_over_time(axis, cumulative, first_plot_day, line_kwargs)

        if idx not in labeled_axes:
            format_time_axis(axis, 3, include_labels=False)
        else:
            format_time_axis(axis, 3, date_format='%b-%d')
Exemplo n.º 2
0
def sort_table(ds_first, da_participants, ds_last, sort_col):
    """Assemble the location table and order its rows by one column.

    The result is a pd.DataFrame (rather than an xarray object) because the
    pandas display interface works well with ipywidgets.

    Args:
        ds_first: An ordered xr.DataSet, whose data_vars have dims (location,),
            to display on the left of the table.
        da_participants: An ordered xr.DataArray with dims
            (location, *participant_dims) representing some measure of the
            fraction of participants with different labels. Canonical examples
            are `participant_fraction` or `population_fraction`.
        ds_last: An ordered xr.DataSet, whose data_vars have dims (location,),
            to display on the right of the table. May be None, in which case
            no right-hand columns are appended.
        sort_col: A string representing the column name to sort the rows by.

    Returns:
        A pd.DataFrame with ordered columns, rows sorted by sort_col in
        descending order (missing values last) and values rounded to two
        decimals.
    """
    fraction = plot_utils.unpack_participant_labels(da_participants)
    # Name the array after the single value stored in 'p', then drop 'p'.
    fraction.name = fraction.p.item()
    del fraction['p']

    # Reshape so that each participant label becomes its own column.
    label_df = (fraction.to_dataframe()
                .reset_index('participant_label')
                .pivot(columns='participant_label'))
    label_df.columns = label_df.columns.droplevel()

    labels = plot_utils.get_labels_to_plot(da_participants)
    if sort_col in labels:
        # Show the column being sorted on first among the label columns.
        labels.remove(sort_col)
        labels.insert(0, sort_col)
    label_columns = label_df[labels]

    # Left-hand variables do not depend on participant label.
    left_df = ds_first.to_dataframe()
    table = pd.concat((left_df, label_columns), axis=1)

    if ds_last is not None:
        right_df = ds_last.to_dataframe()
        table = pd.concat((table, right_df), axis=1)

    ordered = table.sort_values(by=sort_col,
                                ascending=False,
                                na_position='last')
    return ordered.round(decimals=2)
Exemplo n.º 3
0
def disp_table(ds,
               box,
               label_to_sort_by,
               num_rows=10,
               var_to_disp='population_fraction'):
    """Make the table, sort by label_to_sort_by, display in box.

    Makes a table displaying population fraction, population, current
    site_activation, proposed_events, and original_events for all locations
    in our ville. The name of the displayed variable is written to the first
    child of box and the styled table to the second.

    The dataset passed in is not modified: the derived 'frac_cap' column is
    attached to a new dataset rather than written into the caller's ds.

    Args:
        ds: An xr.Dataset containing the ville to visualize. Must contain
            'site_activation' (with a time dimension), 'proposed_events',
            'original_events' and var_to_disp.
        box: An ipywidgets.Box containing two outputs.
        label_to_sort_by: A string representing the data column to sort by.
        num_rows: An int representing the number of rows to display
        var_to_disp: A string representing the ds.data_var to show. Must
            have dims (location, *participant_dims). Canonical examples are
            'participant_fraction' and 'population_fraction'.
    """
    # Dataset.assign returns a new dataset, so rebinding the local name keeps
    # the caller's object free of the display-only 'frac_cap' variable.
    ds = ds.assign(frac_cap=ds['site_activation'].isel(time=0))

    # Left-hand columns: keep only those actually present in the dataset.
    left_vars = [
        item for item in ['SR1', 'SR2', 'frac_cap', 'population']
        if item in ds.data_vars
    ]
    ds_first = ds[left_vars]
    ds_last = ds[['proposed_events', 'original_events']]

    int_utils.update_disp(box.children[0], var_to_disp)
    da_participants = ds[var_to_disp]
    table = table_utils.sort_table(ds_first, da_participants, ds_last,
                                   label_to_sort_by)
    # Fixed columns get explicit formats; participant-label columns are
    # rendered as whole-number percentages.
    styled_table = table[:num_rows].style.format({
        'population': '{:,}',
        'frac_cap': '{:.1f}x',
        'proposed_events': '{:.2f}',
        'original_events': '{:.2f}'
    }).format(lambda x: '{:2.0f}%'.format(100 * x),
              subset=plot_utils.get_labels_to_plot(da_participants))
    int_utils.update_disp(box.children[1], styled_table)
Exemplo n.º 4
0
def make_table_button(ds, box, status_button):
    """Build the dropdown that re-sorts the table interactively.

    Args:
        ds: An xr.Dataset containing the ville to visualize.
        box: An ipywidgets.Box containing two outputs.
        status_button: An ipywidgets.Button to indicate when code is running
    Return:
        table_dropdown: An ipywidgets.Dropdown to sort the table.
    """
    def on_sort_change(change):
        # Mark the UI as busy while the table is being rebuilt, then
        # re-display it sorted by the newly selected value.
        int_utils.set_status(status_button, 'Not_Ready')
        disp_table(ds, box, change['new'])
        int_utils.set_status(status_button, 'Ready')

    # Participant labels first, followed by the fixed sortable columns.
    sort_options = plot_utils.get_labels_to_plot(ds.participants)
    sort_options.extend(['population', 'proposed_events', 'original_events'])

    table_dropdown = int_utils.new_dropdown(sort_options, 'Sort by:')
    table_dropdown.observe(on_sort_change, type='change', names='value')
    return table_dropdown
Exemplo n.º 5
0
def loc_plots(ds, box, loc_to_plot=None):
    """Make plots summarizing incidence and recruitment in one location.

    Show data specific to one location. Three figures are produced and written
    to the three children of box, in order: the incidence, the cumulative
    recruits over all participants, and the cumulative recruits in each
    subgroup, all as functions of time.

    Args:
        ds: An xr.Dataset containing the ville to visualize.
        box: An ipywidgets.Box containing three outputs.
        loc_to_plot: A ds.location.coord specifying the location to plot.
            Defaults to the first entry of ds.location.
    """
    if loc_to_plot is None:
        loc_to_plot = ds.location.values[0]

    first_day = ville_config.FIRST_PLOT_DAY
    proposed_color = colors_config.ville_styles['gray_ville_3']['color']
    original_color = colors_config.ville_styles['highlight_ville_2']['color']
    line = '-'

    # Restrict every series to the requested location.
    incidence = ds.incidence_flattened.sel(location=loc_to_plot)
    historical_incidence = ds.historical_incidence.sel(location=loc_to_plot)
    proposed_part = plot_utils.join_in_time(ds, 'participants')
    proposed_part = proposed_part.sel(location=loc_to_plot)
    original_part = plot_utils.join_in_time(ds, 'original_participants')
    original_part = original_part.sel(location=loc_to_plot)

    # Figure 1: incidence.
    # No original as the user cannot interactively control the incidence.
    fig, axis = plot_utils.new_figure()
    for series in (incidence, historical_incidence):
        plot.incidence(axis, series, first_day, 'k', '-')
    plot.format_time_axis(axis, date_format='%b-%d')
    # TODO find a better way to align plots
    axis.text(-0.25, 1.1, f'Individual Trial Site \n {loc_to_plot}',
              ha='left', va='bottom', transform=axis.transAxes, fontsize=16.)
    int_utils.update_disp(box.children[0], fig)

    # Figure 2: total recruits over time.
    fig, axis = plot_utils.new_figure()
    # TODO figure out grids
    axis.text(0.5, 1.1, ' ',
              ha='left', va='bottom', transform=axis.transAxes, fontsize=16.)
    # Assume everything but time is a participant dimension.
    # The arrays aren't unpacked, so this is not equivalent to
    # get_labels_to_plot.
    participant_dims = list(proposed_part.dims)
    participant_dims.remove(plot_utils.find_time_dim(proposed_part))

    # Original first, then proposed.
    for part, color in ((original_part, original_color),
                        (proposed_part, proposed_color)):
        plot.cum_recruits(axis, part.sum(participant_dims), first_day, color,
                          line)
    plot.format_time_axis(axis, date_format='%b-%d')
    axis.set_title('Cumulative Recruits \n All Participants')
    int_utils.update_disp(box.children[1], fig)

    # Figure 3: recruits per subgroup over time, laid out in two columns.
    num_labels = len(plot_utils.get_labels_to_plot(proposed_part))
    num_cols = 2
    # Round up so a trailing partial row still gets its own row of axes.
    full_rows, leftover = divmod(num_labels, num_cols)
    num_rows = full_rows + (1 if leftover else 0)
    fig, axes = plot_utils.make_subplots(num_rows, num_cols)
    fig.suptitle('Cumulative Recruits')
    for part, color in ((original_part, original_color),
                        (proposed_part, proposed_color)):
        plot.cum_subrecruits(axes, part, first_day, color, line)
    int_utils.update_disp(box.children[2], fig)
Exemplo n.º 6
0
def summary_plots(ds, box, efficacies=(0.55, 0.75)):
    """Make plots summarizing total recruitment and time to success.

    These plots sum information across all the ville locations. Plots the total
    number of recruits in each subgroup and the expected time to success for
    assumed vaccine efficacies. The recruitment figure is written to the first
    child of box and the time-to-success figure to the second.

    Args:
        ds: An xr.Dataset containing the ville to visualize.
        box: An ipywidgets.Box containing two outputs.
        efficacies: A tuple of floats representing the various assumed vaccine
            efficacies. One column of plots is drawn per entry. The defaults
            are expressed as fractions (0.55, 0.75); values are passed
            unchanged to plot.tts and plot.tts_diff.
    """
    # TODO add hist_events and hist_recruits

    pc = colors_config.ville_styles['gray_ville_3']['color']
    oc = colors_config.ville_styles['highlight_ville_2']['color']
    ls = '-'

    # Difference in recruits by participant label
    fig, a = plot_utils.make_subplots(1, 2, (5, 5.5), sharex=False)
    fig.suptitle('Recruitment', fontsize=16.)

    # TODO: investigate grids
    a[0].text(-0.55,
              1.25,
              'Trial Simulation - All Sites',
              ha='left',
              va='bottom',
              transform=a[0].transAxes,
              fontsize=20.)
    a[0].text(-0.55,
              1.24,
              'Compares original and proposed trial simulations',
              ha='left',
              va='top',
              transform=a[0].transAxes,
              fontsize=12.)

    joined_part = plot_utils.join_in_time(ds, 'participants')
    original_joined_part = plot_utils.join_in_time(ds, 'original_participants')
    proposed_unpack = plot_utils.unpack_participant_labels(joined_part)
    original_unpack = plot_utils.unpack_participant_labels(
        original_joined_part)
    labels_to_plot = plot_utils.get_labels_to_plot(joined_part)

    plot.recruit_diffs('participant_label', a[0],
                       proposed_unpack.sel(participant_label=labels_to_plot),
                       original_unpack.sel(participant_label=labels_to_plot),
                       True)
    # For visibility, turn spines off
    plot.turn_spines_off(a[0])
    # To draw attention to xlabels, make them a little bigger
    a[0].tick_params(axis='x', labelsize=12.5)
    a[0].set_title('Proposed - original recruits')

    # Total recruits
    plot.recruits('participant_label',
                  a[1],
                  original_unpack.sel(participant_label=labels_to_plot),
                  oc,
                  ls,
                  label='original')
    plot.recruits('participant_label',
                  a[1],
                  proposed_unpack.sel(participant_label=labels_to_plot),
                  pc,
                  ls,
                  label='proposed')
    # Spines are left on for the totals plot.
    a[1].set_title('Total participants')
    fig.legend(bbox_to_anchor=(1.06, 1.125))

    int_utils.update_disp(box.children[0], fig)

    # Time to success: one column per efficacy, two rows
    # (distribution on top, proposed-minus-original difference below).
    num_cols = len(efficacies)
    num_rows = 2
    fig, a = plot_utils.make_subplots(num_rows,
                                      num_cols, (10, 6.0),
                                      sharey='row')
    a[0].text(0.0,
              1.25,
              'Success day probability distribution',
              horizontalalignment='left',
              transform=a[0].transAxes,
              fontsize=16.)
    # The loop below addresses the second row as a[i + num_cols], so its first
    # axis is a[num_cols]. Anchoring the heading at a[num_rows] only landed on
    # the right axis when exactly two efficacies were given.
    second_row_first = a[num_cols]
    second_row_first.text(0.0,
                          1.0,
                          'Proposed - original success day',
                          ha='left',
                          va='bottom',
                          transform=second_row_first.transAxes,
                          fontsize=16.)

    for i, efficacy in enumerate(efficacies):
        # Top row: time-to-success for proposed and original trials.
        ax = a[i]
        plot.tts(ax, ds.control_arm_events, efficacy, pc, ls)
        plot.tts(ax, ds.original_control_arm_events, efficacy, oc, ls)
        ax.set_title(f'{efficacy} Efficacy')
        ax.xaxis.set_tick_params(which='both', labelbottom=True)

        # Bottom row: difference between proposed and original.
        ax = a[i + num_cols]
        plot.turn_spines_off(ax)
        ax.tick_params(axis='y', labelsize=12.5)
        plot.tts_diff(ax, ds.control_arm_events,
                      ds.original_control_arm_events, efficacy)

    int_utils.update_disp(box.children[1], fig)