예제 #1
0
def test_plotting():
    '''
    Smoke-test several sciris plotting helpers.

    Each helper named in ``enabled`` is exercised in turn; figures are only
    drawn when the module-level ``doplot`` flag is truthy.

    Returns:
        The current matplotlib figure (``pl.gcf()``).
    '''
    enabled = {'hex2rgb', 'arraycolors', 'gridcolors', 'surf3d', 'bar3d'}

    # Hex string -> RGB, both with and without the leading '#'
    if 'hex2rgb' in enabled:
        converted = [sc.hex2rgb(hexstr) for hexstr in ('#fff', 'fabcd8')]
        for rgb in converted:
            print(rgb)

    # Map each column of a random array (column c offset by c) to colors
    if 'arraycolors' in enabled:
        npts, ncols = 1000, 5
        values = pl.rand(npts, ncols)
        values += pl.arange(ncols)  # Broadcast: adds c to every entry of column c
        xvals = pl.rand(npts)
        yvals = pl.rand(npts)
        colors = sc.arraycolors(values)
        if doplot:
            pl.figure(figsize=(20, 16))
            for col in range(ncols):
                pl.scatter(xvals + col, yvals, s=50, c=colors[:, col])

    # Generate palettes of several sizes (optionally showing the demo plot)
    if 'gridcolors' in enabled:
        palettes = {ncolors: sc.gridcolors(ncolors=ncolors, demo=doplot) for ncolors in (8, 18, 28)}
        for ncolors, palette in palettes.items():
            print(f'\n{ncolors} colors:', palette)

    # 3D surface of smoothed Gaussian noise
    if 'surf3d' in enabled:
        smoothed = sc.smooth(pl.randn(50, 50), 20)
        if doplot:
            sc.surf3d(smoothed)

    # 3D bar chart of smoothed uniform noise
    if 'bar3d' in enabled:
        smoothed = sc.smooth(pl.rand(20, 20))
        if doplot:
            sc.bar3d(smoothed)

    return pl.gcf()
예제 #2
0
def plot_scens(to_plot=None, scens=None, do_save=None, fig_path=None, fig_args=None, plot_args=None,
         scatter_args=None, axis_args=None, fill_args=None, legend_args=None, date_args=None,
         show_args=None, mpl_args=None, n_cols=None, grid=False, commaticks=True, setylim=True,
         log_scale=False, colors=None, labels=None, do_show=None, sep_figs=False, fig=None, ax=None, **kwargs):
    '''
    Plot the results of a scenario -- see Scenarios.plot() for documentation.

    For every subplot in ``to_plot`` and every result key it contains, each
    scenario is drawn as a line (``scendata.best``) with a shaded band between
    ``scendata.low`` and ``scendata.high``. Strain-resolved results (keys in
    ``sim.result_keys('strain')``) are instead drawn as one line per strain,
    colored by ``sc.gridcolors(n_strains)`` and labelled 'wild type' or the
    strain's label; the ``colors``/``labels`` arguments are not used for those.

    Note: although ``scens`` defaults to None, ``scens.base_sim`` is read
    immediately, so a Scenarios object must always be supplied.

    Returns:
        Whatever ``tidy_up`` returns for the created figure(s).
    '''

    # Handle inputs
    args = handle_args(fig_args=fig_args, plot_args=plot_args, scatter_args=scatter_args, axis_args=axis_args, fill_args=fill_args,
                   legend_args=legend_args, show_args=show_args, date_args=date_args, mpl_args=mpl_args, **kwargs)
    to_plot, n_cols, n_rows = handle_to_plot('scens', to_plot, n_cols, sim=scens.base_sim, check_ready=False) # Since this sim isn't run
    fig, figs = create_figs(args, sep_figs, fig, ax)

    # Do the plotting
    default_colors = sc.gridcolors(ncolors=len(scens.sims)) # One fallback color per scenario
    for pnum,title,reskeys in to_plot.enumitems():
        ax = create_subplots(figs, fig, ax, n_rows, n_cols, pnum, args.fig, sep_figs, log_scale, title)
        reskeys = sc.promotetolist(reskeys) # In case it's a string
        for reskey in reskeys:
            resdata = scens.results[reskey]
            for snum,scenkey,scendata in resdata.enumitems():
                sim = scens.sims[scenkey][0] # Pull out the first sim in the list for this scenario
                strain_keys = sim.result_keys('strain')
                if reskey in strain_keys:
                    # Strain-resolved result: rows of best/low/high are indexed by strain
                    ns = sim['n_strains']
                    strain_colors = sc.gridcolors(ns)
                    for strain in range(ns):
                        res_y = scendata.best[strain,:]
                        color = strain_colors[strain]  # Choose the color
                        label = 'wild type' if strain == 0 else sim['strains'][strain - 1].label # sim['strains'] excludes the wild type, hence strain-1
                        ax.fill_between(scens.tvec, scendata.low[strain,:], scendata.high[strain,:], color=color, **args.fill)  # Create the uncertainty bound
                        ax.plot(scens.tvec, res_y, label=label, c=color, **args.plot)  # Plot the actual line
                        if args.show['data']:
                            plot_data(sim, ax, reskey, args.scatter, color=color)  # Plot the data
                else:
                    res_y = scendata.best
                    color = set_line_options(colors, scenkey, snum, default_colors[snum])  # Choose the color
                    label = set_line_options(labels, scenkey, snum, scendata.name)  # Choose the label
                    ax.fill_between(scens.tvec, scendata.low, scendata.high, color=color, **args.fill)  # Create the uncertainty bound
                    ax.plot(scens.tvec, res_y, label=label, c=color, **args.plot)  # Plot the actual line
                    if args.show['data']:
                        plot_data(sim, ax, reskey, args.scatter, color=color)  # Plot the data

                # These run once per scenario (inside the scenario loop), re-applying to the same axes
                if args.show['interventions']:
                    plot_interventions(sim, ax) # Plot the interventions
                if args.show['ticks']:
                    reset_ticks(ax, sim, args.date) # Optionally reset tick marks (useful for e.g. plotting weeks/months)
        # NOTE(review): title_grid_legend also configures the title and grid, so those are
        # skipped too when show['legend'] is False -- confirm this is intended
        if args.show['legend']:
            title_grid_legend(ax, title, grid, commaticks, setylim, args.legend, pnum==0) # Configure the title, grid, and legend -- only show legend for first

    return tidy_up(fig, figs, sep_figs, do_save, fig_path, do_show, args)
예제 #3
0
    def plot_cascade(self, vertical=True, cutoff=200e3):
        '''
        Plot the optimized-spending "investment cascade" as a stepped bar chart.

        Interventions are sorted by ICER (ascending) and only those whose
        optimized spending (``opt_spend``) exceeds ``cutoff`` are kept. Each one
        is drawn as a bar segment that starts where the cumulative spending of
        the preceding interventions ends. The running total, in millions, is
        written next to each step. Tick labels are the interventions'
        short names, colored to match their bars.

        Args:
            vertical (bool): if True, draw horizontal bars with labels on the
                y-axis; otherwise draw vertical bars with rotated labels on
                the x-axis.
            cutoff (float): interventions with ``opt_spend`` at or below this
                value are omitted (default 200e3, the previously hard-coded value).

        Returns:
            The created matplotlib figure.
        '''
        if vertical:
            fig_size = (12, 12)
            ax_size = [0.45, 0.05, 0.5, 0.9]
        else:
            fig_size = (16, 8)
            ax_size = [0.05, 0.45, 0.9, 0.5]
        df = sc.dcp(self.data)  # Work on a copy so sorting doesn't modify self.data
        fig = pl.figure(figsize=fig_size)
        df.sort(col='icer', reverse=False)
        DA_data = hp.arr(df['opt_spend'])
        inds = sc.findinds(DA_data > cutoff)
        DA_data = DA_data[inds]
        DA_data /= 1e6  # Convert to millions
        DA_labels = df['shortname'][inds]
        npts = len(DA_data)
        colors = sc.gridcolors(npts, limits=(0.25, 0.75))
        x = np.arange(npts)

        # Cumulative spending, computed once instead of re-summing each prefix:
        # ends[pt] == sum(DA_data[:pt+1]) and starts[pt] == sum(DA_data[:pt])
        ends = np.cumsum(DA_data)
        starts = np.concatenate(([0.0], ends[:-1]))

        prop = 0.9  # Bar thickness
        pl.axes(ax_size)
        for pt in range(npts):
            loc = x[pt:]  # Draw this segment on row pt and on every later row
            this = DA_data[pt]
            start = starts[pt]
            amount = ends[pt]
            amountstr = '%0.1f' % amount
            color = colors[pt]
            if vertical:
                pl.barh(loc, width=this, left=start, height=prop, color=color)
                pl.text(amount,
                        x[pt],
                        amountstr,
                        verticalalignment='center',
                        color=color)
            else:
                pl.bar(loc, height=this, bottom=start, width=prop, color=color)
                pl.text(x[pt],
                        amount + 1,
                        amountstr,
                        horizontalalignment='center',
                        color=color)
        if vertical:
            pl.xlabel('Spending for optimized investment cascade')
            pl.gca().set_yticks(x)
            ticklabels = pl.gca().set_yticklabels(DA_labels)
        else:
            pl.ylabel('Optimized investment cascade')
            pl.gca().set_xticks(x)
            ticklabels = pl.gca().set_xticklabels(DA_labels, rotation=90)
        for t, tl in enumerate(ticklabels):
            tl.set_color(colors[t])  # Match each label to its bar color

        pl.gca().set_facecolor('none')
        pl.title('Investment cascade')
        return fig
예제 #4
0
def plot_sim(to_plot=None, sim=None, do_save=None, fig_path=None, fig_args=None, plot_args=None,
         scatter_args=None, axis_args=None, fill_args=None, legend_args=None, date_args=None,
         show_args=None, mpl_args=None, n_cols=None, grid=False, commaticks=True,
         setylim=True, log_scale=False, colors=None, labels=None, do_show=None, sep_figs=False,
         fig=None, ax=None, **kwargs):
    '''
    Plot the results of a single simulation -- see Sim.plot() for documentation.

    For every subplot in ``to_plot``, each result key is drawn as a line of
    ``res.values`` against ``sim.results['t']``. A shaded band is added when both
    ``res.low`` and ``res.high`` are set. Strain-resolved results (keys in
    ``sim.result_keys('strain')``) are read from ``sim.results['strain']`` and
    drawn as one line per strain, colored by ``sc.gridcolors(n_strains)``. The
    ``colors``/``labels`` arguments are not used for those.

    Note: although ``sim`` defaults to None, it is dereferenced immediately, so
    a Sim must always be supplied.

    Returns:
        Whatever ``tidy_up`` returns for the created figure(s).
    '''

    # Handle inputs
    args = handle_args(fig_args=fig_args, plot_args=plot_args, scatter_args=scatter_args, axis_args=axis_args, fill_args=fill_args,
                       legend_args=legend_args, show_args=show_args, date_args=date_args, mpl_args=mpl_args, **kwargs)
    to_plot, n_cols, n_rows = handle_to_plot('sim', to_plot, n_cols, sim=sim)
    fig, figs = create_figs(args, sep_figs, fig, ax)

    # Do the plotting
    strain_keys = sim.result_keys('strain')
    for pnum,title,keylabels in to_plot.enumitems():
        ax = create_subplots(figs, fig, ax, n_rows, n_cols, pnum, args.fig, sep_figs, log_scale, title)
        for resnum,reskey in enumerate(keylabels):
            res_t = sim.results['t']
            if reskey in strain_keys:
                # Strain-resolved result: rows of values/low/high are indexed by strain
                res = sim.results['strain'][reskey]
                ns = sim['n_strains']
                strain_colors = sc.gridcolors(ns)
                for strain in range(ns):
                    color = strain_colors[strain]  # Choose the color
                    label = 'wild type' if strain == 0 else sim['strains'][strain-1].label # sim['strains'] excludes the wild type, hence strain-1
                    if res.low is not None and res.high is not None:
                        ax.fill_between(res_t, res.low[strain,:], res.high[strain,:], color=color, **args.fill)  # Create the uncertainty bound
                    ax.plot(res_t, res.values[strain,:], label=label, **args.plot, c=color)  # Actually plot the sim!
            else:
                res = sim.results[reskey]
                color = set_line_options(colors, reskey, resnum, res.color)  # Choose the color
                label = set_line_options(labels, reskey, resnum, res.name)  # Choose the label
                if res.low is not None and res.high is not None:
                    ax.fill_between(res_t, res.low, res.high, color=color, **args.fill)  # Create the uncertainty bound
                ax.plot(res_t, res.values, label=label, **args.plot, c=color)  # Actually plot the sim!
            # For strain results, `color` here is the last strain's color from the loop above
            if args.show['data']:
                plot_data(sim, ax, reskey, args.scatter, color=color)  # Plot the data
            if args.show['ticks']:
                reset_ticks(ax, sim, args.date) # Optionally reset tick marks (useful for e.g. plotting weeks/months)
        # Interventions are drawn once per subplot (after all its result keys)
        if args.show['interventions']:
            plot_interventions(sim, ax) # Plot the interventions
        if args.show['legend']:
            title_grid_legend(ax, title, grid, commaticks, setylim, args.legend) # Configure the title, grid, and legend

    return tidy_up(fig, figs, sep_figs, do_save, fig_path, do_show, args)
예제 #5
0
 def plot_result(self, key, colors=None, labels=None, *args, **kwargs):
     '''
     Convenience method for plotting one result -- arguments passed to Sim.plot_result().

     If the sims have been combined or reduced, only ``self.base_sim`` is
     plotted, and ``colors``/``labels`` are ignored. Otherwise every sim is
     plotted onto the same figure, one color and label per sim.

     Args:
         key (str): the result key to plot
         colors (list): one color per sim; defaults to ``sc.gridcolors(len(self))``
         labels (list): one label per sim; defaults to each sim's ``label``
         args (list): extra positional arguments, passed to Sim.plot_result() after ``key``
         kwargs (dict): passed to Sim.plot_result(). ``setylim`` is forced to
             False for all but the last sim, which receives the caller's
             value (default True).

     Returns:
         The figure returned by the last Sim.plot_result() call.
     '''
     if self.which in ['combined', 'reduced']:
         fig = self.base_sim.plot_result(key, *args, **kwargs)
     else:
         fig = None
         if colors is None:
             colors = sc.gridcolors(len(self))
         if labels is None:
             labels = [sim.label for sim in self.sims]
         orig_setylim = kwargs.get('setylim', True)
         last = len(self.sims) - 1
         for s, sim in enumerate(self.sims):
             # Only the final sim gets the caller's setylim
             kwargs['setylim'] = orig_setylim if s == last else False
             # Pass key positionally (as in the combined branch above) so any extra
             # positional args follow it instead of colliding with a key= keyword
             fig = sim.plot_result(key, *args, fig=fig, color=colors[s], label=labels[s], **kwargs)
     return fig
예제 #6
0
def plot_scens(scens, data_check=None, dday=None, test_data=None, to_plot=None, do_save=None, fig_path=None, fig_args=None, plot_args=None,
         scatter_args=None, axis_args=None, fill_args=None, legend_args=None, show_args=None,
         as_dates=True, dateformat=None, interval=None, n_cols=None, font_size=18, font_family=None,
         grid=False, commaticks=True, setylim=True, log_scale=False, colors=None, labels=None,
         do_show=True, sep_figs=False, fig=None):
    '''
    Plot the results of a scenario -- see Scenarios.plot() for documentation.

    Each scenario's best estimate and low/high band are passed through
    ``smooth`` before plotting.

    Extra arguments in this variant:
        data_check: if given, ``data_check[reskey]`` is scattered against
            ``range(data_check.index.size)`` with label 'Data' instead of
            calling ``plot_data``. This is done only once (see ``was_plot``).
        dday (int): if given, only the last ``dday`` points of the time vector
            and of each smoothed series are plotted.
        test_data: forwarded to ``plot_data`` when ``data_check`` is None.

    NOTE(review): ``show_args`` is accepted but never passed to ``handle_args``,
    so it has no effect here -- confirm.

    Returns:
        Whatever ``tidy_up`` returns for the created figure(s).
    '''

    # Handle inputs
    args = handle_args(fig_args, plot_args, scatter_args, axis_args, fill_args, legend_args)
    to_plot, n_cols, n_rows = handle_to_plot('scens', to_plot, n_cols, sim=scens.base_sim)
    fig, figs, ax = create_figs(args, font_size, font_family, sep_figs, fig)
    # NOTE(review): was_plot is never reset, so the data_check scatter appears only on
    # the first subplot / first result key -- confirm this is intended
    was_plot=False
    # Do the plotting
    default_colors = sc.gridcolors(ncolors=len(scens.sims)) # One fallback color per scenario
    for pnum,title,reskeys in to_plot.enumitems():
        ax = create_subplots(figs, fig, ax, n_rows, n_cols, pnum, args.fig, sep_figs, log_scale, title)
        reskeys = sc.promotetolist(reskeys) # In case it's a string
        for reskey in reskeys:
            resdata = scens.results[reskey]
            for snum,scenkey,scendata in resdata.enumitems():
                sim = scens.sims[scenkey][0] # Pull out the first sim in the list for this scenario
                res_y = scendata.best
                color = set_line_options(colors, scenkey, snum, default_colors[snum]) # Choose the color
                label = set_line_options(labels, scenkey, snum, scendata.name) # Choose the label
                if dday is not None:
                    # Restrict to the last dday points (smoothing is applied to the full series first)
                    ax.fill_between(scens.tvec[-dday:], smooth(scendata.low)[-dday:], smooth(scendata.high)[-dday:], color=color, **args.fill) # Create the uncertainty bound
                    ax.plot(scens.tvec[-dday:], smooth(res_y)[-dday:], label=label, c=color, **args.plot) # Plot the actual line
                else:
                    ax.fill_between(scens.tvec, smooth(scendata.low), smooth(scendata.high), color=color, **args.fill) # Create the uncertainty bound
                    ax.plot(scens.tvec, smooth(res_y), label=label, c=color, **args.plot) # Plot the actual line
                if data_check is not None and not was_plot:
                    ax.scatter(range(data_check.index.size), data_check[reskey], label='Data')
                    was_plot=True
                elif data_check is None:
                    if args.show['data']:
                        plot_data(sim, ax, reskey, args.scatter, dday=dday, test_data=test_data, color=color) # Plot the data
                if args.show['interventions']:
                    plot_interventions(sim, ax) # Plot the interventions
                if args.show['ticks']:
                    reset_ticks(ax, sim, interval, as_dates, dateformat) # Optionally reset tick marks (useful for e.g. plotting weeks/months)
        if args.show['legend']:
            title_grid_legend(ax, title, grid, commaticks, setylim, args.legend, pnum==0) # Configure the title, grid, and legend -- only show legend for first

    return tidy_up(fig, figs, sep_figs, do_save, fig_path, do_show)
예제 #7
0
def plot_scens(scens, to_plot=None, do_save=None, fig_path=None, fig_args=None, plot_args=None,
         scatter_args=None, axis_args=None, fill_args=None, legend_args=None, as_dates=True, dateformat=None,
         interval=None, n_cols=1, font_size=18, font_family=None, grid=False, commaticks=True, setylim=True,
         log_scale=False, colors=None, labels=None, do_show=True, sep_figs=False, fig=None):
    '''
    Plot every scenario's results -- see Scenarios.plot() for the full option list.

    For each subplot in ``to_plot``, each result key is drawn once per scenario:
    a band between ``scendata.low`` and ``scendata.high`` plus a line of
    ``scendata.best``. When ``colors``/``labels`` are given they are indexed by
    scenario key; otherwise ``sc.gridcolors`` and the scenario name are used.
    ``dateformat`` is accepted but not used by this version.

    Returns:
        Whatever ``tidy_up`` returns (default file name 'covasim_scenarios.png').
    '''
    # Normalize plotting options, work out the subplot grid, and create the figure(s)
    args = handle_args(fig_args, plot_args, scatter_args, axis_args, fill_args, legend_args)
    to_plot, n_rows = handle_to_plot('scens', to_plot, n_cols, sim=scens.base_sim)
    fig, figs, ax = create_figs(args, font_size, font_family, sep_figs, fig)

    fallback_colors = sc.gridcolors(ncolors=len(scens.sims))  # One per scenario
    for pnum, title, reskeys in to_plot.enumitems():
        ax = create_subplots(figs, fig, ax, n_rows, n_cols, pnum, args.fig, sep_figs, log_scale, title)
        for reskey in sc.promotetolist(reskeys):  # reskeys may be a single string
            for snum, scenkey, scendata in scens.results[reskey].enumitems():
                first_sim = scens.sims[scenkey][0]  # Representative sim for this scenario
                best = scendata.best
                color = fallback_colors[snum] if colors is None else colors[scenkey]
                label = scendata.name if labels is None else labels[scenkey]
                ax.fill_between(scens.tvec, scendata.low, scendata.high, color=color, **args.fill)  # Uncertainty band
                ax.plot(scens.tvec, best, label=label, c=color, **args.plot)  # Best-estimate line
                plot_data(first_sim, reskey, args.scatter)
                plot_interventions(first_sim, ax)
                reset_ticks(ax, first_sim, interval, as_dates)  # e.g. weekly/monthly tick marks
        # Title, grid, and legend; the legend flag is only set for the first subplot
        title_grid_legend(ax, title, grid, commaticks, setylim, args.legend, pnum == 0)

    return tidy_up(fig, figs, sep_figs, do_save, fig_path, do_show, default_name='covasim_scenarios.png')
예제 #8
0
    def makepackage(self,
                    burdenset=None,
                    intervset=None,
                    frpwt=None,
                    equitywt=None,
                    verbose=True,
                    die=False):
        '''
        Make results: build the package's intervention dataframe and store it in ``self.data``.

        Steps:
        1. Copy the critical intervention columns from the intervention set.
        2. Filter out rows whose 'active' column is 0.
        3. Compute coverage as spend / unitcost.
        4. Accumulate total DALYs, maximum avertable DALYs, and total prevalence
           from each intervention's parsed (burden, coverage) pairs.
        5. Compute DALYs averted as spend / ICER, capped at the maximum.
        6. Assign each intervention a color in ``self.colordict``.

        Args:
            burdenset: key of the burden set to use; if None, keep ``self.burdenset``
            intervset: key of the intervention set to use; if None, keep ``self.intervset``
            frpwt (float): weight stored as ``self.frpwt`` (default 0.25)
            equitywt (float): weight stored as ``self.equitywt`` (default 0.25)
            verbose (bool): print a summary line when done
            die (bool): if True, raise when DALYs averted exceed the maximum;
                otherwise print the problems and cap the values

        Returns:
            None

        Raises:
            hp.HPException: if any burden referenced by an intervention could not be found.
        '''
        # Handle inputs
        if burdenset is not None:
            self.burdenset = burdenset  # Warning, name is used both as key and actual set!
        if intervset is not None:
            self.intervset = intervset
        if frpwt is None:
            frpwt = 0.25
        if equitywt is None:
            equitywt = 0.25
        self.frpwt = frpwt
        self.equitywt = equitywt
        burdenset = self.projectref().burden(key=self.burdenset)
        intervset = self.projectref().interv(key=self.intervset)
        intervset.parse()  # Ensure it's parsed
        colnames = intervset.colnames

        # Create new dataframe from a copy of the intervention data
        origdata = sc.dcp(intervset.data)
        critical_cols = [
            'active', 'shortname', 'unitcost', 'spend', 'icer', 'frp', 'equity'
        ]
        df = sc.dataframe()
        for col in critical_cols:  # Copy columns over, mapping logical names to the set's column names
            df[col] = sc.dcp(origdata[colnames[col]])
        df['parsedbc'] = sc.dcp(origdata['parsedbc'])  # Since not named
        df.filter_out(key=0, col='active', verbose=True)

        # Calculate people covered (spending/unitcost); self.eps presumably avoids division by zero
        df['coverage'] = hp.arr(
            df['spend']) / (self.eps + hp.arr(df['unitcost']))

        # Pull out DALYS and prevalence
        df.addcol('total_dalys',
                  value=0)  # Value=0 by default, but just to be explicit
        df.addcol('max_dalys', value=0)
        df.addcol('total_prevalence', value=0)
        df.addcol('dalys_averted', value=0)
        notfound = []
        lasterror = None
        for r in range(df.nrows):
            theseburdencovs = df['parsedbc', r]
            for burdencov in theseburdencovs:
                key = burdencov[0]
                val = burdencov[1]  # WARNING, add validation here
                try:
                    thisburden = burdenset.data.findrow(
                        key=key,
                        col=burdenset.colnames['cause'],
                        asdict=True,
                        die=True)
                    df['total_dalys', r] += thisburden[burdenset.colnames['dalys']]
                    df['max_dalys', r] += thisburden[burdenset.colnames['dalys']] * val
                    df['total_prevalence', r] += thisburden[burdenset.colnames['prevalence']]
                except Exception as E:
                    # Keep a reference: Python 3 unbinds E when the except block exits.
                    # Do not read thisburden here -- it is unbound (or stale from a
                    # previous burden) when findrow() is what raised.
                    lasterror = E
                    notfound.append(key)

        # Validation
        if len(notfound):
            errormsg = 'The following burden(s) were not found: "%s"\nError:\n%s' % (
                notfound, str(lasterror))
            raise hp.HPException(errormsg)
        invalid = []
        for r in range(df.nrows):
            df['dalys_averted', r] = df['spend', r] / (self.eps + df['icer', r])
            if df['dalys_averted', r] > df['max_dalys', r]:
                errormsg = 'Data input error: DALYs averted for "%s" greater than total DALYs (%0.0f vs. %0.0f); please reduce total spending, increase ICER, increase DALYs, or increase max coverage' % (
                    df['shortname', r], df['dalys_averted', r], df['max_dalys', r])
                # WARNING, reset to maximum rather than give error if die=False
                df['dalys_averted', r] = df['max_dalys', r]
                invalid.append(errormsg)
        if len(invalid):
            errors = '\n\n'.join(invalid)
            if die:
                raise Exception(errors)
            else:
                print(errors)

        # To populate with optimization results and fixed spending
        self.budget = hp.arr(df['spend']).sum()
        df.addcol('opt_spend')
        df.addcol('opt_dalys_averted')
        df.addcol('fixed')

        # Store colors
        nintervs = df.nrows
        colors = sc.gridcolors(nintervs + 2,
                               asarray=True)[2:]  # Skip black and white
        colordict = sc.odict()
        for c, name in enumerate(df['shortname']):
            colordict[name] = colors[c]
        self.colordict = colordict

        self.data = df  # Store it
        if verbose:
            print(
                'Health package %s recalculated from burdenset=%s and intervset=%s'
                % (self.name, self.burdenset, self.intervset))
        return None
예제 #9
0
    def plot(self,
             to_plot=None,
             dday=None,
             inds=None,
             plot_sims=False,
             color_by_sim=None,
             max_sims=5,
             colors=None,
             labels=None,
             alpha_range=None,
             plot_args=None,
             show_args=None,
             **kwargs):
        '''
        Plot all the sims  -- arguments passed to Sim.plot(). The
        behavior depends on whether or not combine() or reduce() has been called.
        If so, this function by default plots only the combined/reduced sim (which
        you can override with plot_sims=True). Otherwise, it plots a separate line
        for each sim.

        Note that this function is complex because it aims to capture the flexibility
        of both sim.plot() and scens.plot(). By default, if combine() or reduce()
        has been used, it will resemble sim.plot(); otherwise, it will resemble
        scens.plot(). This can be changed via color_by_sim, together with the
        other options.

        Args:
            to_plot      (list) : list or dict of which results to plot; see cv.get_sim_plots() for structure
            dday         (int)  : passed to sim.plot() in both the combined and individual-sim cases
            inds         (list) : if not combined or reduced, the indices of the simulations to plot (if None, plot all)
            plot_sims    (bool) : whether to plot individual sims, even if combine() or reduce() has been used
            color_by_sim (bool) : if True, set colors based on the simulation type; otherwise, color by result type; True implies a scenario-style plotting, False implies sim-style plotting
            max_sims     (int)  : maximum number of sims to use with color-by-sim; can be overriden by other options
            colors       (list) : if supplied, override default colors for color_by_sim (note: the same object is passed to every sim)
            labels       (list) : if supplied, override default labels for color_by_sim
            alpha_range  (list) : a 2-element list/tuple/array providing the range of alpha values to use to distinguish the lines
            plot_args    (dict) : passed to sim.plot()
            show_args    (dict) : passed to sim.plot()
            kwargs       (dict) : passed to sim.plot()

        **Examples**::

            sim = cv.Sim()
            msim = cv.MultiSim(sim)
            msim.run()
            msim.plot() # Plots individual sims
            msim.reduce()
            msim.plot() # Plots the combined sim
        '''

        # Plot a single curve, possibly with a range
        if not plot_sims and self.which in ['combined', 'reduced']:
            fig = self.base_sim.plot(to_plot=to_plot,
                                     dday=dday,
                                     colors=colors,
                                     **kwargs)

        # Plot individual sims on top of each other
        else:

            # Initialize
            fig = kwargs.pop('fig', None)
            orig_show = kwargs.get('do_show', True)
            orig_setylim = kwargs.get('setylim', True)
            # Turn the legend on by default; caller-supplied legend_args take precedence
            kwargs['legend_args'] = sc.mergedicts(
                {'show_legend': True}, kwargs.get('legend_args'))

            # Handle indices
            if inds is None:
                inds = np.arange(len(self.sims))
            n_sims = len(inds)

            # Handle what style of plotting to use:
            if color_by_sim is None:
                if n_sims <= max_sims:
                    color_by_sim = True
                else:
                    color_by_sim = False

            # Handle what to plot
            if to_plot is None:
                if color_by_sim:
                    to_plot = cvd.get_scen_plots()
                else:
                    to_plot = cvd.get_sim_plots()

            # Handle colors
            if colors is None:
                if color_by_sim:
                    colors = sc.gridcolors(ncolors=n_sims)
                else:
                    colors = [None] * n_sims  # So we can iterate over it
            else:
                colors = [colors] * n_sims  # Again, for iteration

            # Handle alpha if not using colors
            if alpha_range is None:
                if color_by_sim:
                    alpha_range = [0.8, 0.8]  # We're using color to distinguish sims, so don't need alpha
                else:
                    alpha_range = [0.8, 0.3]  # We're using alpha to distinguish sims
            alphas = np.linspace(alpha_range[0], alpha_range[1], n_sims)

            # Plot
            for s, ind in enumerate(inds):
                sim = self.sims[ind]

                final_plot = (s == n_sims - 1)  # Check if this is the final plot

                # Handle the legend and labels
                if final_plot:
                    merged_show_args = show_args
                    kwargs['do_show'] = orig_show
                    kwargs['setylim'] = orig_setylim
                else:
                    merged_show_args = False  # Only show things like data the last time it's plotting
                    kwargs['do_show'] = False  # On top of that, don't show the plot at all unless it's the last time
                    kwargs['setylim'] = False

                # Optionally set the label for the first max_sims sims
                if color_by_sim is True and s < max_sims:
                    if labels is None:
                        merged_labels = sim.label
                    else:
                        merged_labels = labels[s]
                elif final_plot and not color_by_sim:
                    merged_labels = labels
                else:
                    merged_labels = ''

                # Actually plot; forward dday just as the combined branch above does
                merged_plot_args = sc.mergedicts(
                    {'alpha': alphas[s]},
                    plot_args)  # Need a new variable to avoid overwriting
                fig = sim.plot(fig=fig,
                               to_plot=to_plot,
                               dday=dday,
                               colors=colors[s],
                               labels=merged_labels,
                               plot_args=merged_plot_args,
                               show_args=merged_show_args,
                               **kwargs)

        return fig
예제 #10
0
def test_gof(doplot=False):
    '''
    Compare goodness-of-fit estimators on synthetic actual/predicted pairs.

    Builds several pairs:
    - uniform data with increasing additive noise;
    - uniform data with multiplicative noise;
    - normal data with additive noise.

    It then optionally shifts each pair so its minimum is ``minval`` and
    equalizes the two series' means. Finally it scores each pair with every
    estimator in ``methods`` via ``pe.gof``.

    Args:
        doplot (bool): if True, plot each pair and its per-estimator scores

    Returns:
        results (array): shape (npairs, nmethods). Entry [p, m] is the score
        of estimator m on pair p, or 0 if the estimator raised. When
        ``normalize`` is set, each column is divided by its maximum across pairs.
    '''
    n = 100
    noise1 = 0.1
    noise2 = 0.5
    noise3 = 2
    normalize_data = True  # Shift each pair so both series have equal means
    normalize = True  # Divide each estimator's scores by its maximum across pairs
    ylim = True
    minval = 0.1
    subtract_min = True

    true_unif   = pl.rand(n)
    true_normal = pl.randn(n) + 5

    pairs = sc.objdict({
        'uniform_low': [true_unif,
                        true_unif + noise1*pl.rand(n)],
        'uniform_med': [true_unif,
                        true_unif + noise2*pl.rand(n)],
        'uniform_high': [true_unif,
                         true_unif + noise3*pl.rand(n)],
        'uniform_mul': [true_unif,
                         true_unif * 2*pl.rand(n)],
        'normal': [true_normal,
                   true_normal + noise2*pl.randn(n)],
        })

    # Remove any DC offset, while ensuring it doesn't go negative
    for pairkey,pair in pairs.items():
        actual = sc.dcp(pair[0])  # Copy: several pairs share the same true_unif array
        predicted = sc.dcp(pair[1])

        if subtract_min:
            minmin = min(actual.min(), predicted.min())
            actual    += minval - minmin
            predicted += minval - minmin

        if normalize_data:
            a_mean = actual.mean()
            p_mean = predicted.mean()
            if a_mean > p_mean:
                predicted += a_mean - p_mean
            elif a_mean < p_mean:
                actual += p_mean - a_mean

        # Store back unconditionally so subtract_min takes effect even when normalize_data is False
        pairs[pairkey][0] = actual
        pairs[pairkey][1] = predicted

    for pair in pairs.values():
        print(pair[0].mean())
        print(pair[1].mean())

    methods = [
        'mean fractional',
        'mean absolute',
        'median fractional',
        'median absolute',
        'max_error',
        'mean_absolute_error',
        'mean_squared_error',
        'mean_squared_log_error',
        'median_absolute_error',
        # 'r2_score',
        'mean_poisson_deviance',
        'mean_gamma_deviance',
        ]

    npairs = len(pairs)
    nmethods = len(methods)

    results = pl.zeros((npairs, nmethods))
    for m,method in enumerate(methods):
        print(f'\nWorking on method {method}:')
        for p,pairkey,pair in pairs.enumitems():
            print(f'    Working on data {pairkey}')
            actual = pair[0]
            predicted = pair[1]
            try:
                results[p,m] = pe.gof(actual, predicted, estimator=method)
            except Exception as E:
                print(' '*10 + f'Failed: {str(E)}')  # Leave the score at 0 and keep going
            if pl.isnan(results[p,m]):
                print(' '*10 + f'Returned NaN: {method}')

    if normalize:
        colmax = results.max(axis=0)
        colmax[colmax == 0] = 1  # Avoid 0/0 for estimators that failed (left at 0) on every pair
        results = results/colmax

    if doplot:
        pl.figure(figsize=hifigsize)
        pl.subplots_adjust(**axis_args)
        colors = sc.gridcolors(nmethods)
        for p,pairkey,pair in pairs.enumitems():

            lastrow = (p == npairs-1)

            # Left column: scatter of predicted vs. actual
            pl.subplot(npairs, 2, p*2+1)
            actual = pair[0]
            predicted = pair[1]
            pl.scatter(actual, predicted)
            pl.title(f"Data for {pairkey}")
            pl.ylabel('Predicted')
            if lastrow:
                pl.xlabel('Actual')

            # Right column: one bar per estimator
            pl.subplot(npairs, 2, p*2+2)
            for m,method in enumerate(methods):
                pl.bar(m, results[p,m], facecolor=colors[m], label=f'{m}={method}')
            pl.gca().set_xticks(pl.arange(nmethods))
            pl.ylabel('Goodness of fit')
            pl.title(f"Normalized GOF for {pairkey}")
            if lastrow:
                pl.xlabel('Estimator')
                pl.legend()
            if ylim:
                pl.ylim([0,1])

    return results
예제 #11
0
def plot():
    '''
    Build the multi-panel calibration figure ("Fig. 1: Calibration").

    Panel layout (the letters are drawn with pl.figtext below):
        a -- cumulative diagnoses, with an inset bar chart by age
        b -- cumulative deaths, with an inset bar chart by age
        c -- cumulative and active infections vs. SCAN survey estimates
        d -- effective reproduction number
        e -- kernel-density plots of the columns in ``keys`` for df1 vs. df2
        f -- SafeGraph workplace/community and school mobility

    Reads several module-level names defined elsewhere in the script:
    base_sim, sims, plotter, format_ax, plot_intervs, scan_file,
    safegraph_file, keys, mapping, df1, df2 and day_stride.

    Returns:
        fig: the matplotlib Figure handle
    '''

    # Create the figure
    fig = pl.figure(num='Fig. 1: Calibration', figsize=(24, 20))
    # Figure-fraction coordinates for the panel letters (tx*: x, ty*: y)
    tx1, ty1 = 0.005, 0.97
    tx2, ty2 = 0.545, 0.66
    ty3 = 0.34
    fsize = 40
    pl.figtext(tx1, ty1, 'a', fontsize=fsize)
    pl.figtext(tx1, ty2, 'b', fontsize=fsize)
    pl.figtext(tx2, ty1, 'c', fontsize=fsize)
    pl.figtext(tx2, ty2, 'd', fontsize=fsize)
    pl.figtext(tx1, ty3, 'e', fontsize=fsize)
    pl.figtext(tx2, ty3, 'f', fontsize=fsize)

    #%% Fig. 1A: diagnoses
    # Axes rectangle [left, bottom, width, height] in figure fractions
    x0, y0, dx, dy = 0.055, 0.73, 0.47, 0.24
    ax1 = pl.axes([x0, y0, dx, dy])
    format_ax(ax1, base_sim)
    plotter('cum_diagnoses',
            sims,
            ax1,
            calib=True,
            label='Model',
            ylabel='Cumulative diagnoses')
    pl.legend(loc='lower right', frameon=False)

    #%% Fig. 1B: deaths
    y0b = 0.42
    ax2 = pl.axes([x0, y0b, dx, dy])
    format_ax(ax2, base_sim)
    plotter('cum_deaths',
            sims,
            ax2,
            calib=True,
            label='Model',
            ylabel='Cumulative deaths')
    pl.legend(loc='lower right', frameon=False)

    #%% Fig. 1A-B inserts (histograms)

    agehists = []

    # NOTE(review): assumes the first analyzer of every sim is the age-histogram
    # analyzer -- confirm against how the sims are built
    for s, sim in enumerate(sims):
        agehist = sim['analyzers'][0]
        if s == 0:
            age_data = agehist.data  # Observed data is taken from the first sim only
        agehists.append(agehist.hists[-1])  # Last stored histogram of each sim

    # Observed data
    x = age_data['age'].values
    pos = age_data['cum_diagnoses'].values
    death = age_data['cum_deaths'].values

    # Model outputs
    mposlist = []
    mdeathlist = []
    for hists in agehists:
        mposlist.append(hists['diagnosed'])
        mdeathlist.append(hists['dead'])
    mposarr = np.array(mposlist)
    mdeatharr = np.array(mdeathlist)

    # Median and 2.5%-97.5% quantile range across sims (axis 0 = sims)
    low_q = 0.025
    high_q = 0.975
    mpbest = pl.median(mposarr, axis=0)
    mplow = pl.quantile(mposarr, q=low_q, axis=0)
    mphigh = pl.quantile(mposarr, q=high_q, axis=0)
    mdbest = pl.median(mdeatharr, axis=0)
    mdlow = pl.quantile(mdeatharr, q=low_q, axis=0)
    mdhigh = pl.quantile(mdeatharr, q=high_q, axis=0)

    # Bar width and offset: data bars at x - off, model bars at x + w - off
    w = 4
    off = 2

    # Insets
    x0s, y0s, dxs, dys = 0.11, 0.84, 0.17, 0.13
    ax1s = pl.axes([x0s, y0s, dxs, dys])
    c1 = [0.3, 0.3, 0.6]
    c2 = [0.6, 0.7, 0.9]
    xx = x + w - off
    pl.bar(x - off, pos, width=w, label='Data', facecolor=c1)
    pl.bar(xx, mpbest, width=w, label='Model', facecolor=c2)
    # Vertical black lines show the quantile range of the model bars
    for i, ix in enumerate(xx):
        pl.plot([ix, ix], [mplow[i], mphigh[i]], c='k')
    ax1s.set_xticks(np.arange(0, 81, 20))
    pl.xlabel('Age')
    pl.ylabel('Cases')
    sc.boxoff(ax1s)
    pl.legend(frameon=False, bbox_to_anchor=(0.7, 1.1))

    y0sb = 0.53
    ax2s = pl.axes([x0s, y0sb, dxs, dys])
    c1 = [0.5, 0.0, 0.0]
    c2 = [0.9, 0.4, 0.3]
    pl.bar(x - off, death, width=w, label='Data', facecolor=c1)
    pl.bar(x + w - off, mdbest, width=w, label='Model', facecolor=c2)  # x + w - off equals xx
    for i, ix in enumerate(xx):
        pl.plot([ix, ix], [mdlow[i], mdhigh[i]], c='k')
    ax2s.set_xticks(np.arange(0, 81, 20))
    pl.xlabel('Age')
    pl.ylabel('Deaths')
    sc.boxoff(ax2s)
    pl.legend(frameon=False)
    sc.boxoff(ax1s)  # NOTE(review): repeats the call made for the cases inset; harmless

    #%% Fig. 1C: infections
    x0, dx = 0.60, 0.38
    ax3 = pl.axes([x0, y0, dx, dy])
    # `sim` below is the last element of `sims`, left over from the loop above
    format_ax(ax3, sim)

    # Plot SCAN data
    # NOTE(review): lower/upper/mean look like population fractions, scaled to
    # counts with pop_size -- confirm against the SCAN file
    pop_size = 2.25e6
    scan = pd.read_csv(scan_file)
    for i, r in scan.iterrows():
        label = "Data" if i == 0 else None  # Only the first point gets a legend entry
        ts = np.mean(sim.day(r['since'], r['to']))  # Midpoint of the survey window, in sim days
        low = r['lower'] * pop_size
        high = r['upper'] * pop_size
        mean = r['mean'] * pop_size
        ax3.plot([ts] * 2, [low, high], alpha=1.0, color='k', zorder=1000)
        ax3.plot(ts,
                 mean,
                 'o',
                 markersize=7,
                 color='k',
                 alpha=0.5,
                 label=label,
                 zorder=1000)

    # Plot simulation
    plotter('cum_infections',
            sims,
            ax3,
            calib=True,
            label='Cumulative\ninfections\n(modeled)',
            ylabel='Infections')
    plotter('n_infectious',
            sims,
            ax3,
            calib=True,
            label='Active\ninfections\n(modeled)',
            ylabel='Infections',
            flabel=False)
    pl.legend(loc='upper left', frameon=False)
    pl.ylim([0, 130e3])
    plot_intervs(sim)

    #%% Fig. 1D: R_eff
    ax4 = pl.axes([x0, y0b, dx, dy])
    format_ax(ax4, sim, key='r_eff')
    plotter('r_eff',
            sims,
            ax4,
            calib=True,
            label='$R_{eff}$ (modeled)',
            ylabel=r'Effective reproduction number')
    pl.axhline(1, linestyle='--', lw=3, c='k', alpha=0.5)  # R_eff = 1 reference line
    pl.legend(loc='upper right', frameon=False)
    plot_intervs(sim)

    #%% Fig. 1E

    # Do the plotting -- the 2x2 subplot grid is squeezed into the lower-left region
    pl.subplots_adjust(left=0.04,
                       right=0.52,
                       bottom=0.03,
                       top=0.35,
                       wspace=0.12,
                       hspace=0.50)

    for i, k in enumerate(keys):
        eax = pl.subplot(2, 2, i + 1)

        c1 = [0.2, 0.5, 0.8]
        c2 = [1.0, 0.5, 0.0]
        c3 = [0.1, 0.6, 0.1]
        # NOTE(review): `shade` is deprecated in newer seaborn releases (`fill`
        # replaces it) -- check the installed version
        sns.kdeplot(df1[k],
                    shade=1,
                    linewidth=3,
                    label='',
                    color=c1,
                    alpha=0.5)
        sns.kdeplot(df2[k],
                    shade=0,
                    linewidth=3,
                    label='',
                    color=c2,
                    alpha=0.5)

        pl.title(mapping[k])
        pl.xlabel('')
        pl.yticks([])
        # With a 2x2 grid, i is 0..3, so only the first panel gets this label
        # (NOTE(review): possibly intended i % 2 for the left column)
        if not i % 4:
            pl.ylabel('Density')

        # Leave 30% headroom above the densities for the text annotations
        yfactor = 1.3
        yl = pl.ylim()
        pl.ylim([yl[0], yl[1] * yfactor])

        m1 = np.median(df1[k])
        m2 = np.median(df2[k])
        m1std = df1[k].std()
        m2std = df2[k].std()
        pl.axvline(m1, c=c1, ymax=0.9, lw=3, linestyle='--')
        pl.axvline(m2, c=c2, ymax=0.9, lw=3, linestyle='--')

        # NOTE(review): fmt is redefined on every iteration but never called in
        # this function; candidate for removal
        def fmt(lab, val, std=-1):
            if val < 0.1:
                valstr = f'{lab} = {val:0.4f}'
            elif val < 1.0:
                valstr = f'{lab} = {val:0.2f}±{std:0.2f}'
            else:
                valstr = f'{lab} = {val:0.1f}±{std:0.1f}'
            if std < 0:
                valstr = valstr.split('±')[0]  # Discard STD if not used
            return valstr

        # Fixed x-range per parameter family; unknown families are an error
        if k.startswith('bc'):
            pl.xlim([0, 100])
        elif k == 'beta':
            pl.xlim([3, 5])
        elif k.startswith('tn'):
            pl.xlim([0, 50])
        else:
            raise Exception(f'Please assign key {k}')

        # Horizontal text position as a fraction of the x-range; a key missing
        # from this dict raises KeyError below
        xl = pl.xlim()
        xfmap = dict(
            beta=0.15,
            bc_wc1=0.30,
            bc_lf=0.35,
            tn=0.55,
        )

        xf = xfmap[k]
        x0 = xl[0] + xf * (xl[1] - xl[0])

        # Vertical text positions relative to the original (pre-headroom) y-limit
        ypos1 = yl[1] * 0.97
        ypos2 = yl[1] * 0.77
        ypos3 = yl[1] * 0.57

        if k == 'beta':  # Use 2 s.f. instead of 1
            pl.text(x0, ypos1, f'M: {m1:0.2f} ± {m1std:0.2f}', c=c1)
            pl.text(x0, ypos2, f'N: {m2:0.2f} ± {m2std:0.2f}', c=c2)
            pl.text(x0,
                    ypos3,
                    rf'$\Delta$: {(m2-m1):0.2f} ± {(m1std+m2std):0.2f}',
                    c=c3)
        else:
            pl.text(x0, ypos1, f'M: {m1:0.1f} ± {m1std:0.1f}', c=c1)
            pl.text(x0, ypos2, f'N: {m2:0.1f} ± {m2std:0.1f}', c=c2)
            pl.text(x0,
                    ypos3,
                    rf'$\Delta$: {(m2-m1):0.1f} ± {(m1std+m2std):0.1f}',
                    c=c3)

        sc.boxoff(ax=eax)

    #%% Fig. 1F: SafeGraph
    x0, y0c, dyc = 0.60, 0.03, 0.30
    ax5 = pl.axes([x0, y0c, dx, dyc])
    format_ax(ax5, sim, key='r_eff')
    fn = safegraph_file
    df = pd.read_csv(fn)
    week = df['week']
    days = sim.day(week.values.tolist())
    # Convert the mobility columns to percentages
    s = df['p.tot.schools'].values * 100
    w = df['p.tot.no.schools'].values * 100

    # From Fig. 2
    colors = sc.gridcolors(5)
    wcolor = colors[3]  # Work color/community
    scolor = colors[1]  # School color

    pl.plot(days,
            w,
            'd-',
            c=wcolor,
            markersize=15,
            lw=3,
            alpha=0.9,
            label='Workplace and\ncommunity mobility data')
    pl.plot(days,
            s,
            'd-',
            c=scolor,
            markersize=15,
            lw=3,
            alpha=0.9,
            label='School mobility data')
    sc.setylim()
    xmin, xmax = ax5.get_xlim()
    ax5.set_xticks(np.arange(xmin, xmax, day_stride))
    pl.ylabel('Relative mobility (%)')
    pl.legend(loc='upper right', frameon=False)
    plot_intervs(sim)

    return fig
예제 #12
0
파일: sim.py 프로젝트: haohu1/covasim
    def plot(self,
             to_plot=None,
             do_save=None,
             fig_path=None,
             fig_args=None,
             plot_args=None,
             scatter_args=None,
             axis_args=None,
             as_dates=True,
             interval=None,
             dateformat=None,
             font_size=18,
             font_family=None,
             use_grid=True,
             use_commaticks=True,
             do_show=True,
             verbose=None):
        '''
        Plot the results -- can supply arguments for both the figure and the plots.

        One subplot row is drawn per entry of ``to_plot`` (stacked vertically),
        so any number of groups can be supplied.

        Args:
            to_plot (dict): Nested dict of results to plot; see default_sim_plots for structure
            do_save (bool or str): Whether or not to save the figure. If a string, save to that filename.
            fig_path (str): Path to save the figure
            fig_args (dict): Dictionary of kwargs to be passed to pl.figure()
            plot_args (dict): Dictionary of kwargs to be passed to pl.plot()
            scatter_args (dict): Dictionary of kwargs to be passed to pl.scatter()
            axis_args (dict): Dictionary of kwargs to be passed to pl.subplots_adjust()
            as_dates (bool): Whether to plot the x-axis as dates or time points
            interval (int): Interval between tick marks
            dateformat (str): Date string format, e.g. '%B %d'
            font_size (int): Size of the font
            font_family (str): Font face
            use_grid (bool): Whether or not to plot gridlines
            use_commaticks (bool): Plot y-axis with commas rather than scientific notation
            do_show (bool): Whether or not to show the figure; if False the figure is closed
            verbose (bool): Display a bit of extra information

        Returns:
            fig: Figure handle
        '''

        if verbose is None:
            verbose = self['verbose']
        sc.printv('Plotting...', 1, verbose)

        if to_plot is None:
            to_plot = default_sim_plots
        to_plot = sc.odict(to_plot)  # In case it's supplied as a dict

        # Handle input arguments -- merge user input with defaults
        fig_args = sc.mergedicts({'figsize': (16, 12)}, fig_args)
        plot_args = sc.mergedicts({'lw': 3, 'alpha': 0.7}, plot_args)
        scatter_args = sc.mergedicts({'s': 150, 'marker': 's'}, scatter_args)
        axis_args = sc.mergedicts(
            {
                'left': 0.1,
                'bottom': 0.05,
                'right': 0.9,
                'top': 0.97,
                'wspace': 0.2,
                'hspace': 0.25
            }, axis_args)

        fig = pl.figure(**fig_args)
        pl.subplots_adjust(**axis_args)
        pl.rcParams['font.size'] = font_size
        if font_family:
            pl.rcParams['font.family'] = font_family

        res = self.results  # Shorten since heavily used
        has_data = self.data is not None and len(self.data)

        # One row per group. A hard-coded row count makes pl.subplot() raise
        # as soon as to_plot holds more groups than rows.
        n_rows = len(to_plot)

        # Line i of every group uses colors[i]; default=1 keeps max() from
        # raising on an empty to_plot
        n_colors = max((len(tp) for tp in to_plot.values()), default=1)
        colors = sc.gridcolors(n_colors)

        # Define the data mapping. Must be here since uses functions
        if has_data:
            data_mapping = {
                'cum_exposed': pl.cumsum(self.data['new_infections']),
                'cum_diagnosed': pl.cumsum(self.data['new_positives']),
                'cum_tested': pl.cumsum(self.data['new_tests']),
                'infections': self.data['new_infections'],
                'tests': self.data['new_tests'],
                'diagnoses': self.data['new_positives'],
            }
        else:
            data_mapping = {}

        for p, title, keylabels in to_plot.enumitems():
            ax = pl.subplot(n_rows, 1, p + 1)
            for i, key, label in keylabels.enumitems():
                this_color = colors[i]
                y = res[key].values
                pl.plot(res['t'], y, label=label, **plot_args, c=this_color)
                if key in data_mapping:
                    pl.scatter(self.data['day'],
                               data_mapping[key],
                               c=[this_color],
                               **scatter_args)
            if has_data:
                # Invisible (NaN) point whose only purpose is a black 'Data' label
                pl.scatter(pl.nan,
                           pl.nan,
                           c=[(0, 0, 0)],
                           label='Data',
                           **scatter_args)

            pl.grid(use_grid)
            cvu.fixaxis(self)
            if use_commaticks:
                sc.commaticks()
            pl.title(title)

            # Optionally reset tick marks (useful for e.g. plotting weeks/months)
            if interval:
                xmin, xmax = ax.get_xlim()
                ax.set_xticks(pl.arange(xmin, xmax + 1, interval))

            # Set xticks as dates
            if as_dates:
                xticks = ax.get_xticks()
                xticklabels = self.inds2dates(xticks, dateformat=dateformat)
                ax.set_xticklabels(xticklabels)

            # Plot interventions
            for intervention in self['interventions']:
                intervention.plot(self, ax)

        # Ensure the figure actually renders or saves
        if do_save:
            if fig_path is None:  # No figpath provided - see whether do_save is a figpath
                if isinstance(do_save, str):
                    fig_path = do_save  # It's a string, assume it's a filename
                else:
                    fig_path = 'covasim.png'  # Just give it a default name
            fig_path = sc.makefilepath(
                fig_path)  # Ensure it's valid, including creating the folder
            pl.savefig(fig_path)

        if do_show:
            pl.show()
        else:
            pl.close(fig)

        return fig
예제 #13
0
    testdelay=base_cum_diag,
    trtime=
    base_cum_diag,  # Denominator is people you start contact tracing from, not those traced
)

pointcolor = [0.1, 0.4, 0.0]  # High mobility, high intervention
pointcolor2 = 'k'  # Status quo, darker
c1 = [0.4, 0.4, 0.4]
c2 = [0.2, 0.2, 0.2]

# Plot each seed (eind) separately? When enabled, every unique seed gets its
# own color; otherwise everything is treated as one group drawn in c1.
sepinds = 0
if sepinds:
    einds = df1['eind'].values
    eis = np.unique(einds)
    cols = sc.gridcolors(len(eis))
else:
    eis = [0]
    cols = [c1]

#%% Bottom row: surface plots
print('Calculating surfaces...')

df2 = cv.load(dffile2)

# Diverging colormap: 128 reversed-Blues colors followed by 128 Oranges colors.
# NOTE(review): pl.cm.get_cmap is deprecated in recent matplotlib releases
# (matplotlib.colormaps[...] replaces it) -- check the installed version.
bottom = pl.cm.get_cmap('Oranges', 128)
top = pl.cm.get_cmap('Blues_r', 128)
newcolors = np.vstack((top(np.linspace(0, 1, 128)),
                       bottom(np.linspace(0, 1, 128))))
newcmp = mpl.colors.ListedColormap(newcolors, name='OrangeBlue')
예제 #14
0
def run_sim(sim_pars=None, epi_pars=None, show_animation=False, verbose=True):
    '''
    Create, run, and plot everything.

    Converts the user-supplied parameter entries into sim parameters, clipping
    each value to its [min, max] default range and falling back to the default
    value when conversion fails. It then runs the sim, builds Plotly graph JSON,
    and encodes the results as downloadable files. Every stage is wrapped in
    its own try/except: failures are appended to the returned 'err' string
    instead of being raised.

    Args:
        sim_pars (dict): simulation parameter entries, each a dict with at least
            'best' and 'name'; each 'best' is overwritten with the value used
        epi_pars (dict): epidemiological parameter entries, same structure
        show_animation (bool): whether to also append an animation of the people
        verbose (bool): passed to the sim; also controls printing of the inputs

    Returns:
        dict with keys 'err', 'sim_pars', 'epi_pars', 'graphs', 'files', 'summary'
    '''

    err = ''
    # Bound before any try block so that a failure in an early stage cannot
    # cause a NameError in a later, unprotected stage
    web_pars = {}
    sim = None

    try:
        # Fix up things that JavaScript mangles
        orig_pars = cv.make_pars()
        defaults = get_defaults(merge=True)
        web_pars['verbose'] = verbose  # Control verbosity here

        for key, entry in {**sim_pars, **epi_pars}.items():
            print(key, entry)

            best = defaults[key]['best']
            minval = defaults[key]['min']
            maxval = defaults[key]['max']

            try:
                web_pars[key] = np.clip(float(entry['best']), minval, maxval)
            except Exception:
                user_key = entry['name']
                user_val = entry['best']
                err1 = f'Could not convert parameter "{user_key}", value "{user_val}"; using default value instead\n'
                print(err1)
                err += err1
                web_pars[key] = best
            # Echo the value actually used back into the caller's dict
            if key in sim_pars:
                sim_pars[key]['best'] = web_pars[key]
            else:
                epi_pars[key]['best'] = web_pars[key]

        # Convert durations
        web_pars['dur'] = sc.dcp(orig_pars['dur'])  # This is complicated, so just copy it
        web_pars['dur']['exp2inf']['par1'] = web_pars.pop('web_exp2inf')
        web_pars['dur']['inf2sym']['par1'] = web_pars.pop('web_inf2sym')
        web_pars['dur']['crit2die']['par1'] = web_pars.pop('web_timetodie')
        web_dur = web_pars.pop('web_dur')
        for key in ['asym2rec', 'mild2rec', 'sev2rec', 'crit2rec']:
            web_pars['dur'][key]['par1'] = web_dur

        # Add the intervention. Both web-only keys are popped unconditionally:
        # otherwise they would be passed on to sim.update_pars() when no
        # intervention day is set.
        web_pars['interventions'] = []
        web_int_day = web_pars.pop('web_int_day')
        web_int_eff = web_pars.pop('web_int_eff', None)
        if web_int_day is not None:
            # Kept as a list: the plotting code below indexes sim['interventions'][0]
            web_pars['interventions'] = [
                cv.change_beta(days=web_int_day, changes=(1 - web_int_eff))
            ]

        # Handle CFR -- ignore symptoms and set to 1
        prog_pars = cv.get_default_prognoses(by_age=False)
        web_pars['rel_symp_prob'] = 1.0 / prog_pars.symp_prob
        web_pars['rel_severe_prob'] = 1.0 / prog_pars.severe_prob
        web_pars['rel_crit_prob'] = 1.0 / prog_pars.crit_prob
        web_pars['rel_death_prob'] = web_pars.pop('web_cfr') / prog_pars.death_prob

    except Exception as E:
        err2 = f'Parameter conversion failed! {str(E)}\n'
        print(err2)
        err += err2

    # Create the sim and update the parameters
    try:
        sim = cv.Sim()
        sim['prog_by_age'] = False  # So the user can override this value
        sim['timelimit'] = max_time  # Set the time limit
        if web_pars['seed'] == 0:
            web_pars['seed'] = None  # Reset
        sim.update_pars(web_pars)
    except Exception as E:
        err3 = f'Sim creation failed! {str(E)}\n'
        print(err3)
        err += err3

    if verbose:
        print('Input parameters:')
        print(web_pars)

    # Core algorithm
    try:
        sim.run(do_plot=False)
    except Exception as E:
        err4 = f'Sim run failed! {str(E)}\n'
        print(err4)
        err += err4

    if sim is not None and sim.stopped:
        try:  # Assume it stopped because of the time, but if not, don't worry
            day = sim.stopped['t']
            time_exceeded = f"The simulation stopped on day {day} because run time limit ({sim['timelimit']} seconds) was exceeded. Please reduce the population size and/or number of days simulated."
            err += time_exceeded
        except Exception:  # Best-effort message only
            pass

    # Core plotting
    graphs = []
    try:

        to_plot = sc.dcp(cv.default_sim_plots)
        for p, title, keylabels in to_plot.enumitems():
            fig = go.Figure()
            colors = sc.gridcolors(len(keylabels))
            for i, key, label in keylabels.enumitems():
                # Convert the 0-1 RGB triple to Plotly's 'rgb(r,g,b)' string
                this_color = 'rgb(%d,%d,%d)' % (
                    255 * colors[i][0], 255 * colors[i][1], 255 * colors[i][2])
                y = sim.results[key][:]
                fig.add_trace(
                    go.Scatter(x=sim.results['t'][:],
                               y=y,
                               mode='lines',
                               name=label,
                               line_color=this_color))

            if sim['interventions']:
                interv_day = sim['interventions'][0].days[0]
                if 0 < interv_day < sim['n_days']:
                    fig.add_shape(
                        dict(type="line",
                             xref="x",
                             yref="paper",
                             x0=interv_day,
                             x1=interv_day,
                             y0=0,
                             y1=1,
                             name='Intervention',
                             line=dict(width=0.5, dash='dash')))
                    fig.update_layout(annotations=[
                        dict(x=interv_day,
                             y=1.07,
                             xref="x",
                             yref="paper",
                             text="Intervention start",
                             showarrow=False)
                    ])

            fig.update_layout(title={'text': title},
                              xaxis_title='Day',
                              yaxis_title='Count',
                              autosize=True)

            # Round-trip through JSON to attach the responsive config
            output = {'json': fig.to_json(), 'id': str(sc.uuid())}
            d = json.loads(output['json'])
            d['config'] = {'responsive': True}
            output['json'] = json.dumps(d)
            graphs.append(output)

        graphs.append(plot_people(sim))

        if show_animation:
            graphs.append(animate_people(sim))

    except Exception as E:
        err5 = f'Plotting failed! {str(E)}\n'
        print(err5)
        err += err5

    # Create and send output files (base64 encoded content)
    files = {}
    summary = {}
    try:
        datestamp = sc.getdate(dateformat='%Y-%b-%d_%H.%M.%S')

        ss = sim.to_xlsx()
        files['xlsx'] = {
            'filename':
            f'COVASim_results_{datestamp}.xlsx',
            'content':
            'data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,'
            + base64.b64encode(ss.blob).decode("utf-8"),
        }

        json_string = sim.to_json()
        files['json'] = {
            'filename':
            f'COVASim_results_{datestamp}.txt',
            'content':
            'data:application/text;base64,' +
            base64.b64encode(json_string.encode()).decode("utf-8"),
        }

        # Summary output
        summary = {
            'days': sim.npts - 1,
            'cases': round(sim.results['cum_exposed'][-1]),
            'deaths': round(sim.results['cum_deaths'][-1]),
        }
    except Exception as E:
        err6 = f'File saving failed! {str(E)}\n'
        print(err6)
        err += err6

    output = {}
    output['err'] = err
    output['sim_pars'] = sim_pars
    output['epi_pars'] = epi_pars
    output['graphs'] = graphs
    output['files'] = files
    output['summary'] = summary

    return output
예제 #15
0
    def plot(self, keys=None, width=0.8, font_size=18, fig_args=None, axis_args=None, plot_args=None):
        '''
        Plot the fit of the model to the data. For each result, plot the data
        and the model; the difference; and the loss (weighted difference). Also
        plots the loss as a function of time.

        The top row holds two stacked-bar summaries (daily and cumulative
        mismatch) over the time-series keys. Each key then gets one column in
        the three rows below it: data vs. simulation, differences, and losses.

        Args:
            keys (list): which keys to plot (default, all)
            width (float): bar width
            font_size (float): size of font
            fig_args (dict): passed to pl.figure()
            axis_args (dict): passed to pl.subplots_adjust()
            plot_args (dict): passed to pl.plot()

        Returns:
            figs (list): a list containing the single figure created
        '''

        fig_args  = sc.mergedicts(dict(figsize=(36,22)), fig_args)
        axis_args = sc.mergedicts(dict(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.3, hspace=0.3), axis_args)
        plot_args = sc.mergedicts(dict(lw=4, alpha=0.5, marker='o'), plot_args)
        pl.rcParams['font.size'] = font_size

        if keys is None:
            keys = self.keys + self.custom_keys
        n_keys = len(keys)

        # Position (within keys) of the last time-series key. The summary axes
        # get their labels, titles and legends once that key has been drawn.
        # Deriving it from `keys`, rather than assuming len(self.keys)-1, keeps
        # them when the caller passes a subset or a reordering of the keys.
        ts_positions = [k for k, key in enumerate(keys) if key in self.keys]
        last_ts_pos = ts_positions[-1] if ts_positions else None

        loss_ax = None
        colors = sc.gridcolors(n_keys)
        n_rows = 4

        figs = [pl.figure(**fig_args)]
        pl.subplots_adjust(**axis_args)
        main_ax1 = pl.subplot(n_rows, 2, 1)
        main_ax2 = pl.subplot(n_rows, 2, 2)
        bottom = sc.objdict() # Keep track of the bottoms for plotting cumulative
        bottom.daily = np.zeros(self.sim_npts)
        bottom.cumul = np.zeros(self.sim_npts)
        for k,key in enumerate(keys):
            if key in self.keys: # It's a time series, plot with days and dates
                days      = self.inds.sim[key] # The "days" axis (or not, for custom keys)
                daylabel  = 'Day'
            else: #It's custom, we don't know what it is
                days      = np.arange(len(self.losses[key])) # Just use indices
                daylabel  = 'Index'

            # Cumulative totals can't mix daily and non-daily inputs, so skip custom keys
            if key in self.keys:
                for i,ax in enumerate([main_ax1, main_ax2]):

                    if i == 0:
                        data = self.losses[key]
                        ylabel = 'Daily mismatch'
                        title = 'Daily total mismatch'
                        base = bottom.daily
                    else:
                        data = np.cumsum(self.losses[key])
                        ylabel = 'Cumulative mismatch'
                        title = f'Cumulative mismatch: {self.mismatch:0.3f}'
                        base = bottom.cumul

                    dates = self.sim_results['date'][days] # Show these with dates, rather than days, as a reference point
                    # Stack this key's bars on top of those of the previous keys
                    ax.bar(dates, data, width=width, bottom=base[self.inds.sim[key]], color=colors[k], label=f'{key}')

                    if i == 0:
                        bottom.daily[self.inds.sim[key]] += self.losses[key]
                    else:
                        bottom.cumul = np.cumsum(bottom.daily)

                    if k == last_ts_pos:
                        ax.set_xlabel('Date')
                        ax.set_ylabel(ylabel)
                        ax.set_title(title)
                        ax.legend()

            # Row 2: data vs. simulation (subplot index = row*n_keys + column)
            pl.subplot(n_rows, n_keys, 1*n_keys + k + 1)
            pl.plot(days, self.pair[key].data, c='k', label='Data', **plot_args)
            pl.plot(days, self.pair[key].sim, c=colors[k], label='Simulation', **plot_args)
            pl.title(key)
            if k == 0:
                pl.ylabel('Time series (counts)')
                pl.legend()

            # Row 3: raw differences
            pl.subplot(n_rows, n_keys, 2*n_keys + k + 1)
            pl.bar(days, self.diffs[key], width=width, color=colors[k], label='Difference')
            pl.axhline(0, c='k')
            if k == 0:
                pl.ylabel('Differences (counts)')
                pl.legend()

            # Row 4: losses, all sharing one y-axis for comparability
            loss_ax = pl.subplot(n_rows, n_keys, 3*n_keys + k + 1, sharey=loss_ax)
            pl.bar(days, self.losses[key], width=width, color=colors[k], label='Losses')
            pl.xlabel(daylabel)
            pl.title(f'Total loss: {self.losses[key].sum():0.3f}')
            if k == 0:
                pl.ylabel('Losses')
                pl.legend()

        return figs
예제 #16
0
파일: model.py 프로젝트: haohu1/covasim
    def plot(self,
             do_save=None,
             fig_args=None,
             plot_args=None,
             scatter_args=None,
             axis_args=None,
             as_days=True,
             font_size=18,
             use_grid=True,
             verbose=None):
        '''
        Plot the results -- can supply arguments for both the figure and the plots.

        Parameters
        ----------
        do_save : bool or str
            Whether or not to save the figure. If a string, save to that filename.

        fig_args : dict
            Dictionary of kwargs to be passed to pl.figure()

        plot_args : dict
            Dictionary of kwargs to be passed to pl.plot()

        scatter_args : dict
            Dictionary of kwargs to be passed to pl.scatter() for the data points

        axis_args : dict
            Dictionary of kwargs to be passed to pl.subplots_adjust()

        as_days : bool
            Whether to plot the x-axis as days or time points (not currently
            used by this implementation)

        font_size : int
            Font size, applied through pl.rcParams

        use_grid : bool
            Whether or not to draw gridlines

        verbose : bool
            Whether to print progress; defaults to self['verbose']

        Returns
        -------
        Figure handle
        '''

        if verbose is None:
            verbose = self['verbose']
        if verbose:
            print('Plotting...')

        if fig_args is None: fig_args = {'figsize': (26, 16)}
        if plot_args is None: plot_args = {'lw': 3, 'alpha': 0.7}
        if scatter_args is None: scatter_args = {'s': 150, 'marker': 's'}
        if axis_args is None:
            axis_args = {
                'left': 0.1,
                'bottom': 0.05,
                'right': 0.9,
                'top': 0.97,
                'wspace': 0.2,
                'hspace': 0.25
            }

        fig = pl.figure(**fig_args)
        pl.subplots_adjust(**axis_args)
        pl.rcParams['font.size'] = font_size

        res = self.results  # Shorten since heavily used

        # Plot everything
        colors = sc.gridcolors(5)
        to_plot = sc.odict({ # TODO
            'Total counts': sc.odict({'n_susceptible':'Number susceptible',
                                      'n_exposed':'Number exposed',
                                      'n_infectious':'Number infectious',
                                      'cum_diagnosed':'Number diagnosed',
                                    }),
            'Daily counts': sc.odict({'infections':'New infections',
                                      'tests':'Number of tests',
                                      'diagnoses':'New diagnoses',
                                     }),
            })

        # Without data, skip the data overlay instead of failing on self.data[...]
        has_data = self.data is not None and len(self.data)
        if has_data:
            data_mapping = {
                'cum_diagnosed': pl.cumsum(self.data['new_positives']),
                'tests': self.data['new_tests'],
                'diagnoses': self.data['new_positives'],
            }
        else:
            data_mapping = {}

        for p, title, keylabels in to_plot.enumitems():
            pl.subplot(2, 1, p + 1)
            for i, key, label in keylabels.enumitems():
                this_color = colors[i + p]  # Offset by the subplot index
                y = res[key]
                pl.plot(res['t'], y, label=label, **plot_args, c=this_color)
                if key in data_mapping:
                    pl.scatter(self.data['day'],
                               data_mapping[key],
                               c=[this_color],
                               **scatter_args)
            if has_data:
                # Invisible (NaN) point whose only purpose is a black 'Data' label
                pl.scatter(pl.nan,
                           pl.nan,
                           c=[(0, 0, 0)],
                           label='Data',
                           **scatter_args)
            pl.grid(use_grid)
            cv.fixaxis(self)
            pl.ylabel('Count')
            pl.xlabel('Days since index case')
            pl.title(title)

        # Ensure the figure actually renders or saves
        if do_save:
            if isinstance(do_save, str):
                filename = do_save  # It's a string, assume it's a filename
            else:
                filename = 'covid_abm_results.png'  # Just give it a default name
            pl.savefig(filename)

        pl.show()

        return fig
# Fonts and sizes
font_size = 36
font_family = 'Libertinus Sans'
pl.rcParams['font.size'] = font_size
pl.rcParams['font.family'] = font_family
fig = pl.figure(figsize=(24,10))
ax = pl.axes([0.1, 0.11, 0.85, 0.85])

# Load the saved multisim and build the transmission tree of its base sim
# (a cached tree could instead be loaded with sc.loadobj(f'{resfolder}/tt.obj'))
msim = sc.loadobj(f'{resfolder}/uk_sim_FNL.obj')
sim = msim.base_sim
tt = sim.make_transtree()
layer_keys = list(sim.people.layer_keys())
layer_mapping = {k:i for i,k in enumerate(layer_keys)}  # layer key -> column index
n_layers = len(layer_keys)
colors = sc.gridcolors(n_layers)

# Accumulate transmissions per day (row) and layer (column), weighting each
# one by the sim's rescale factor for that day
layer_counts = np.zeros((sim.npts, n_layers))
for source_ind, target_ind in tt.transmissions:
    dd = tt.detailed[target_ind]
    date = dd['date']
    layer_num = layer_mapping[dd['layer']]
    layer_counts[date, layer_num] += sim.rescale_vec[date]

lockdown1 = [sc.readdate('2020-03-23'),sc.readdate('2020-05-31')]
lockdown2 = [sc.readdate('2020-11-05'),sc.readdate('2020-12-03')]
lockdown3 = [sc.readdate('2021-01-04'),sc.readdate('2021-02-08')]

# NOTE(review): labels are positional, so this assumes layer_keys are ordered
# household, school, workplace, community -- confirm for the sim being loaded.
# Layers beyond the named ones fall back to their raw key instead of raising.
labels = ['Household', 'School', 'Workplace', 'Community']
for l in range(n_layers):
    label = labels[l] if l < len(labels) else layer_keys[l]
    ax.plot(sim.datevec, layer_counts[:,l], c=colors[l], lw=3, label=label)
예제 #18
0
    def plot(self, which=None, n=None, axsize=None, figsize=None):
        '''
        Create a bar plot of the top causes of burden. By default, plots the top
        10 causes for each of DALYs, deaths, and prevalence.

        Args:
            which (str or list): burden type(s) to plot; each must be one of
                'dalys', 'deaths', or 'prevalence' (default: all three)
            n (int): maximum number of causes to show (default 10); fewer bars
                are drawn if the data has fewer rows
            axsize (tuple): axes rectangle passed to fig.add_axes()
            figsize (tuple): figure size passed to pl.figure()

        Returns:
            A single figure if ``which`` is not a list, otherwise a list of
            figures (one per entry). The default ``which=None`` expands to a
            list, so it returns a list.

        Raises:
            Exception: if ``which`` is not a known burden type, or if a value to
                plot is missing or cannot be converted to a number. Note that
                ``self.colnames[which]`` is looked up first, so an unknown
                ``which`` missing from ``self.colnames`` raises KeyError there.

        Version: 2018sep27
        '''

        # Set labels
        titles = {
            'dalys': 'Top causes of DALYs',
            'deaths': 'Top causes of mortality',
            'prevalence': 'Most prevalent conditions'
        }

        # Handle options
        if which is None: which = list(titles.keys())
        if n is None: n = 10
        if axsize is None: axsize = (0.65, 0.15, 0.3, 0.8)
        if figsize is None: figsize = (7, 4)
        barw = 0.8

        # Pull out data; give every cause a fixed color so it matches across figures
        df = sc.dcp(self.data)
        nburdens = df.nrows
        colors = sc.gridcolors(nburdens + 2, asarray=True)[2:]  # Skip the first two colors
        colordict = sc.odict()
        for c, cause in enumerate(df[self.colnames['cause']]):
            colordict[cause] = colors[c]

        # Convert to list, remembering whether to return a list or a single figure
        if not isinstance(which, list):
            asarray = False
            whichlist = sc.promotetolist(which)
        else:
            asarray = True
            whichlist = which

        # Loop over each option (may only be one)
        figs = []
        for thiswhich in whichlist:
            colname = self.colnames[thiswhich]
            try:
                thistitle = titles[thiswhich]
                thisxlabel = colname
            except Exception as E:
                errormsg = '"%s" not found, "which" must be one of %s (%s)' % (
                    thiswhich, ', '.join(list(titles.keys())), str(E))
                raise Exception(errormsg) from E

            # Process data: sort by this burden (reverse=True) and keep the top n rows
            df.sort(col=colname, reverse=True)
            topdata = df[:n]
            try:
                barvals = hp.arr(topdata[colname])
            except Exception as E:
                # Find the offending row to give a more useful error message
                for r in range(topdata.nrows):
                    try:
                        float(topdata[colname, r])
                    except Exception as E2:
                        if topdata[colname, r] in ['', None]:
                            errormsg = 'For cause "%s", the "%s" value is missing or empty' % (
                                topdata[self.colnames['cause'], r], colname)
                        else:
                            errormsg = 'For cause "%s", could not convert "%s" value "%s" to number: %s' % (
                                topdata[self.colnames['cause'], r], colname,
                                topdata[colname, r], str(E2))
                        raise Exception(errormsg) from E2
                errormsg = 'An exception was encountered, but could not be reproduced: %s' % str(
                    E)
                raise Exception(errormsg) from E
            barlabels = topdata[self.colnames['cause']].tolist()
            nbars = len(barvals)  # May be fewer than n if there are fewer causes

            # Figure out the units; the data was sorted with reverse=True, so the
            # first value is presumably the largest
            largestval = barvals[0] if nbars else 0
            if largestval > 1e6:
                barvals = barvals / 1e6  # Not in-place: barvals may be an integer array
                unitstr = ' (millions)'
            elif largestval > 1e3:
                barvals = barvals / 1e3
                unitstr = ' (thousands)'
            else:
                unitstr = ''

            # Create plot, with the first (largest) bar at the top
            fig = pl.figure(facecolor='none', figsize=figsize)
            ax = fig.add_axes(axsize)
            ax.set_facecolor('none')
            yaxis = pl.arange(nbars, 0, -1)
            for i in range(nbars):
                thiscause = topdata[self.colnames['cause'], i]
                color = colordict[thiscause]
                pl.barh(yaxis[i],
                        barvals[i],
                        height=barw,
                        facecolor=color,
                        edgecolor='none')
            ax.set_yticks(yaxis)  # One tick per bar, at the bar positions
            ax.set_yticklabels(barlabels)
            sc.SIticks(ax=ax, axis='x')
            ax.set_xlabel(thisxlabel + unitstr)
            ax.set_title(thistitle)
            sc.boxoff()
            figs.append(fig)

        if asarray: return figs
        else: return figs[0]
예제 #19
0
def plot_people(people,
                bins=None,
                width=1.0,
                alpha=0.6,
                fig_args=None,
                axis_args=None,
                plot_args=None,
                do_show=None,
                fig=None):
    '''
    Plot statistics of a population -- see People.plot() for documentation

    The first row shows the age histogram and the cumulative age distribution.
    The next three rows have one column per contact layer and show the age
    distribution of contacts in three ways:
        - total
        - per capita
        - weighted by people.pars['beta_layer'][lk] * people.pars['beta']

    Args:
        people: population object; this function reads people.age,
            people.contacts, people.layer_keys(), people.pars and len(people)
        bins (array): age bins; bin i covers [bins[i], bins[i+1]) and the last
            bin is open-ended (default: 0 to 100)
        width (float): width of the histogram bars
        alpha (float): transparency of the histogram bars
        fig_args (dict): merged into the defaults passed to pl.figure()
        axis_args (dict): merged into the defaults passed to pl.subplots_adjust()
        plot_args (dict): merged into the defaults passed to pl.plot() for the
            cumulative age distribution
        do_show (bool): passed to cvset.handle_show()
        fig (Figure): figure to draw on; a new one is created if None

    Returns:
        fig: the figure
    '''

    # Handle inputs
    if bins is None:
        bins = np.arange(0, 101)

    # Set defaults
    color = [0.1, 0.1, 0.1]  # Color for the age distribution
    n_rows = 4  # Number of rows of plots
    offset = 0.5  # For ensuring the full bars show up
    gridspace = 10  # Spacing of gridlines
    zorder = 10  # So plots appear on top of gridlines

    # Handle other arguments
    fig_args = sc.mergedicts(dict(figsize=(18, 11)), fig_args)
    axis_args = sc.mergedicts(
        dict(left=0.05,
             right=0.95,
             bottom=0.05,
             top=0.95,
             wspace=0.3,
             hspace=0.35), axis_args)
    plot_args = sc.mergedicts(dict(lw=1.5, alpha=0.6, c=color, zorder=zorder),
                              plot_args)

    # Compute statistics
    min_age = min(bins)
    max_age = max(bins)
    edges = np.append(
        bins, np.inf)  # Add an extra bin to end to turn them into edges
    age_counts = np.histogram(people.age, edges)[0]

    # Create the figure
    if fig is None:
        fig = pl.figure(**fig_args)
    pl.subplots_adjust(**axis_args)

    # Plot age histogram
    pl.subplot(n_rows, 2, 1)
    pl.bar(bins,
           age_counts,
           color=color,
           alpha=alpha,
           width=width,
           zorder=zorder)
    pl.xlim([min_age - offset, max_age + offset])
    pl.xticks(np.arange(0, max_age + 1, gridspace))
    pl.grid(True)
    pl.xlabel('Age')
    pl.ylabel('Number of people')
    pl.title(f'Age distribution ({len(people):n} people total)')

    # Plot cumulative distribution
    pl.subplot(n_rows, 2, 2)
    age_sorted = sorted(people.age)
    y = np.linspace(0, 100, len(age_sorted))  # Percentage, not hard-coded!
    pl.plot(age_sorted, y, '-', **plot_args)
    pl.xlim([0, max_age])
    pl.ylim([0, 100])  # Percentage
    pl.xticks(np.arange(0, max_age + 1, gridspace))
    pl.yticks(np.arange(0, 101, gridspace))  # Percentage
    pl.grid(True)
    pl.xlabel('Age')
    pl.ylabel('Cumulative proportion (%)')
    pl.title(
        f'Cumulative age distribution (mean age: {people.age.mean():0.2f} years)'
    )

    # Calculate contacts: each edge is counted once at the age of each endpoint
    lkeys = people.layer_keys()
    n_layers = len(lkeys)
    contact_counts = sc.objdict()
    for lk in lkeys:
        layer = people.contacts[lk]
        p1ages = people.age[layer['p1']]
        p2ages = people.age[layer['p2']]
        contact_counts[lk] = np.histogram(p1ages, edges)[0] + np.histogram(
            p2ages, edges)[0]

    # Per-capita weights are 1/age_counts, with 0 for empty age bins. The explicit
    # out= array is required: with only where=, np.divide leaves the masked
    # entries uninitialized, so empty bins would be plotted with arbitrary values.
    percapita_weight = np.divide(1.0, age_counts,
                                 out=np.zeros(len(age_counts)),
                                 where=age_counts > 0)

    # Plot contacts
    layer_colors = sc.gridcolors(n_layers)
    share_ax = None
    for w, w_type in enumerate(['total', 'percapita', 'weighted'
                                ]):  # Plot contacts in different ways
        for i, lk in enumerate(lkeys):
            n_edges = len(people.contacts[lk])
            if w_type == 'total':
                weight = 1
                total_contacts = 2 * n_edges  # x2 since each contact is undirected
                ylabel = 'Number of contacts'
                title = f'Total contacts for layer "{lk}": {total_contacts:n}'
            elif w_type == 'percapita':
                weight = percapita_weight
                mean_contacts = 2 * n_edges / len(
                    people)  # Factor of 2 since edges are bi-directional
                ylabel = 'Per capita number of contacts'
                title = f'Mean contacts for layer "{lk}": {mean_contacts:0.2f}'
            elif w_type == 'weighted':
                weight = people.pars['beta_layer'][lk] * people.pars['beta']
                total_weight = np.round(weight * 2 * n_edges)
                ylabel = 'Weighted number of contacts'
                title = f'Total weight for layer "{lk}": {total_weight:n}'

            # Rows 2-4 hold the three weightings, one column per layer
            ax = pl.subplot(n_rows,
                            n_layers,
                            n_layers * (w + 1) + i + 1,
                            sharey=share_ax)
            pl.bar(bins,
                   contact_counts[lk] * weight,
                   color=layer_colors[i],
                   width=width,
                   zorder=zorder,
                   alpha=alpha)
            pl.xlim([min_age - offset, max_age + offset])
            pl.xticks(np.arange(0, max_age + 1, gridspace))
            pl.grid(True)
            pl.xlabel('Age')
            pl.ylabel(ylabel)
            pl.title(title)
            if w_type == 'weighted':
                share_ax = ax  # Update shared axis

    cvset.handle_show(do_show)

    return fig
예제 #20
0
def plot():
    '''
    Build and return the four-panel "Fig. 2: Transmission dynamics" figure.

    Panels (labelled a-d via pl.figtext):
        a: daily transmissions per contact layer (ts_ax), with two pie-chart
           insets of layer shares before 2020-03-12 and from 2020-03-23 onward
        b: proportion of transmissions by number of transmissions per case
           (disp_ax), from tt.count_targets(end_day=mar12), with a pie inset
        c: cumulative proportion of transmissions vs. proportion of primary
           infections (cum_ax), with annotated 0%/20%/50% thresholds
        d: proportion of transmissions by days since symptom onset (symp_ax),
           with a pie inset of asymptomatic/presymptomatic/symptomatic shares

    Takes no arguments. Reads sim, tt, wf and pieargs (plus pl, sc, np, cv and
    mdates) from the enclosing module scope; they are not defined here.

    Returns:
        fig: the matplotlib figure
    '''
    fig = pl.figure(num='Fig. 2: Transmission dynamics', figsize=(20,14))
    # Layout in figure-fraction coordinates. pie*/ts*/r3* are the pie insets, the
    # time series, and the bottom row of three panels; *y is the bottom edge,
    # *dx the width and *dy the height.
    piey, tsy, r3y = 0.68, 0.50, 0.07
    piedx, tsdx, r3dx = 0.2, 0.9, 0.25
    piedy, tsdy, r3dy = 0.2, 0.47, 0.35
    pie1x, pie2x = 0.12, 0.65
    tsx = 0.07
    dispx, cumx, sympx = tsx, 0.33+tsx, 0.66+tsx
    # pl.axes takes [left, bottom, width, height]; each call makes the new axes
    # current, so cum_ax is the current axes after this block
    ts_ax   = pl.axes([tsx, tsy, tsdx, tsdy])
    pie_ax1 = pl.axes([pie1x, piey, piedx, piedy])
    pie_ax2 = pl.axes([pie2x, piey, piedx, piedy])
    symp_ax = pl.axes([sympx, r3y, r3dx, r3dy])
    disp_ax = pl.axes([dispx, r3y, r3dx, r3dy])
    cum_ax  = pl.axes([cumx, r3y, r3dx, r3dy])

    # Panel letters, offset left of each panel and placed at its top edge
    off = 0.06
    txtdispx, txtcumx, txtsympx = dispx-off, cumx-off, sympx-off+0.02
    tsytxt = tsy+tsdy
    r3ytxt = r3y+r3dy
    labelsize = 40-wf  # wf comes from module scope (presumably a font-size offset -- confirm)
    pl.figtext(txtdispx, tsytxt, 'a', fontsize=labelsize)
    pl.figtext(txtdispx, r3ytxt, 'b', fontsize=labelsize)
    pl.figtext(txtcumx,  r3ytxt, 'c', fontsize=labelsize)
    pl.figtext(txtsympx, r3ytxt, 'd', fontsize=labelsize)


    #%% Fig. 2A -- Time series plot

    layer_keys = list(sim.layer_keys())
    layer_mapping = {k:i for i,k in enumerate(layer_keys)}
    n_layers = len(layer_keys)
    colors = sc.gridcolors(n_layers)

    # Rows are simulation days, columns are layers; each transmission adds
    # sim.rescale_vec for its day (presumably the population scaling -- confirm).
    # NOTE(review): this loop uses tt.count_transmissions(), while the Fig. 2D
    # loop below iterates tt.transmissions -- confirm which is intended.
    layer_counts = np.zeros((sim.npts, n_layers))
    for source_ind, target_ind in tt.count_transmissions():
        dd = tt.detailed[target_ind]
        date = dd['date']
        layer_num = layer_mapping[dd['layer']]
        layer_counts[date, layer_num] += sim.rescale_vec[date]

    # Intervention dates, as dates and as simulation day indices
    mar12 = cv.date('2020-03-12')
    mar23 = cv.date('2020-03-23')
    mar12d = sim.day(mar12)
    mar23d = sim.day(mar23)

    # Assumes exactly 5 layers, in this order (the pie labels below index
    # positions 0-4 too) -- TODO confirm against sim.layer_keys()
    labels = ['Household', 'School', 'Workplace', 'Community', 'LTCF']
    for l in range(n_layers):
        ts_ax.plot(sim.datevec, layer_counts[:,l], c=colors[l], lw=3, label=labels[l])
    sc.setylim(ax=ts_ax)
    sc.boxoff(ax=ts_ax)
    ts_ax.set_ylabel('Transmissions per day')
    ts_ax.set_xlim([sc.readdate('2020-01-18'), sc.readdate('2020-06-09')])
    ts_ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))
    # One tick every 14 days from day 0 up to 2020-06-09
    ts_ax.set_xticks([sim.date(d, as_date=True) for d in np.arange(0, sim.day('2020-06-09'), 14)])
    ts_ax.legend(frameon=False, bbox_to_anchor=(0.85,0.1))

    # Dashed vertical lines and captions at the two intervention dates; the
    # padding spaces in the captions shift the centred text off the lines
    color = [0.2, 0.2, 0.2]
    ts_ax.axvline(mar12, c=color, linestyle='--', alpha=0.4, lw=3)
    ts_ax.axvline(mar23, c=color, linestyle='--', alpha=0.4, lw=3)
    yl = ts_ax.get_ylim()
    labely = yl[1]*1.015
    ts_ax.text(mar12, labely, 'Schools close                     ', color=color, alpha=0.9, style='italic', horizontalalignment='center')
    ts_ax.text(mar23, labely, '                   Stay-at-home', color=color, alpha=0.9, style='italic', horizontalalignment='center')


    #%% Fig. 2A inset -- Pie charts

    # Percentage of transmissions per layer before Mar 12 (days [0, mar12d)) and
    # from Mar 23 onward (days [mar23d, end))
    pre_counts = layer_counts[0:mar12d, :].sum(axis=0)
    post_counts = layer_counts[mar23d:, :].sum(axis=0)
    pre_counts = pre_counts/pre_counts.sum()*100
    post_counts = post_counts/post_counts.sum()*100

    # Trailing spaces in some labels are presumably manual layout nudges
    lpre = [
        f'Household\n{pre_counts[0]:0.1f}%',
        f'School\n{pre_counts[1]:0.1f}% ',
        f'Workplace\n{pre_counts[2]:0.1f}%    ',
        f'Community\n{pre_counts[3]:0.1f}%',
        f'LTCF\n{pre_counts[4]:0.1f}%',
    ]

    lpost = [
        f'Household\n{post_counts[0]:0.1f}%',
        f'School\n{post_counts[1]:0.1f}%',
        f'Workplace\n{post_counts[2]:0.1f}%',
        f'Community\n{post_counts[3]:0.1f}%',
        f'LTCF\n{post_counts[4]:0.1f}%',
    ]

    # pieargs comes from module scope
    pie_ax1.pie(pre_counts, colors=colors, labels=lpre, **pieargs)
    pie_ax2.pie(post_counts, colors=colors, labels=lpost, **pieargs)

    pie_ax1.text(0, 1.75, 'Transmissions by layer\nbefore schools closed', style='italic', horizontalalignment='center')
    pie_ax2.text(0, 1.75, 'Transmissions by layer\nafter stay-at-home', style='italic', horizontalalignment='center')


    #%% Fig. 2B -- histogram by overdispersion

    # Process targets: per-case counts up to mar12 (presumably the number of
    # people each case infected -- confirm against TransTree.count_targets)
    n_targets = tt.count_targets(end_day=mar12)

    # Handle bins: one integer-wide bin per possible count 0..max_infections
    max_infections = n_targets.max()
    edges = np.arange(0, max_infections+2)

    # Analysis
    counts = np.histogram(n_targets, edges)[0]
    bins = edges[:-1] # Remove last bin since it's an edge
    norm_counts = counts/counts.sum()  # Fraction of cases with each count
    raw_counts = counts*bins  # Transmissions attributable to cases with each count
    total_counts = raw_counts/raw_counts.sum()*100  # ...as a percentage of all transmissions
    n_bins = len(bins)
    index = np.linspace(0, 100, len(n_targets))  # Percentile position of each case (0-100)
    sorted_arr = np.sort(n_targets)
    sorted_sum = np.cumsum(sorted_arr)
    sorted_sum = sorted_sum/sorted_sum.max()*100  # Cumulative % of transmissions, ascending by count
    change_inds = sc.findinds(np.diff(sorted_arr) != 0)  # Last position of each count value before it increases

    # The colormap is set before sc.vectocolor, which presumably uses the
    # current default colormap -- keep this order
    pl.set_cmap('Spectral_r')
    sscolors = sc.vectocolor(n_bins)

    width = 1.0
    for i in range(n_bins):
        disp_ax.bar(bins[i], total_counts[i], width=width, facecolor=sscolors[i])
    disp_ax.set_xlabel('Number of transmissions per case')
    disp_ax.set_ylabel('Proportion of transmissions (%)')
    # NOTE(review): with no ax argument this presumably acts on the current axes
    # (cum_ax at this point), not disp_ax, which is handled explicitly below
    sc.boxoff()
    disp_ax.set_xlim([0.5, 32.5])
    disp_ax.set_xticks(np.arange(0, 32.5, 4))
    sc.boxoff(ax=disp_ax)

    # Pie inset: share of transmissions from cases with 1-2, 3-4, 5-7 and >7
    # transmissions (total_counts is indexed by count; count 0 contributes 0%)
    dpie_ax = pl.axes([dispx+0.05, 0.20, 0.2, 0.2])
    trans1 = total_counts[1:3].sum()
    trans2 = total_counts[3:5].sum()
    trans3 = total_counts[5:8].sum()
    trans4 = total_counts[8:].sum()
    labels = [
        f'1-2:\n{trans1:0.0f}%',
        f' 3-4:\n {trans2:0.0f}%',
        f'5-7: \n{trans3:0.0f}%\n',
        f'>7:  \n{trans4:0.0f}%\n',
        ]
    dpie_args = sc.mergedicts(pieargs, dict(labeldistance=1.2)) # Slightly smaller label distance
    # Indexing sscolors at 12 requires at least 13 bins (max_infections >= 12)
    dpie_ax.pie([trans1, trans2, trans3, trans4], labels=labels, colors=sscolors[[0,4,7,12]], **dpie_args)


    #%% Fig. 2C -- cumulative distribution function

    # Draw a point at each change in per-case count, joined by segments coloured
    # by that count; the count values come from the non-empty bins after the first
    rev_ind = 100 - index
    n_change_inds = len(change_inds)
    change_bins = bins[counts>0][1:]
    for i in range(n_change_inds):
        ib = int(change_bins[i])
        ci = change_inds[i]
        ici = index[ci]
        sci = sorted_sum[ci]
        color = sscolors[ib]
        if i>0:
            cim1 = change_inds[i-1]
            icim1 = index[cim1]
            scim1 = sorted_sum[cim1]
            cum_ax.plot([icim1, ici], [scim1, sci], lw=4, c=color)
        cum_ax.scatter([ici], [sci], s=150, zorder=50-i, c=[color], edgecolor='w', linewidth=0.2)
        # Label only selected counts; xoff/yoff are manual position nudges
        if ib<=6 or ib in [8, 10, 25]:
            xoff = 5 - 2*(ib==1) + 3*(ib>=10) + 1*(ib>=20)
            yoff = 2*(ib==1)
            cum_ax.text(ici-xoff, sci+yoff, ib, fontsize=18-wf, color=color)
    cum_ax.set_xlabel('Proportion of primary infections (%)')
    cum_ax.set_ylabel('Proportion of transmissions (%)')
    xmin = -2
    ymin = -2
    cum_ax.set_xlim([xmin, 102])
    cum_ax.set_ylim([ymin, 102])
    sc.boxoff(ax=cum_ax)

    # Draw horizontal lines and annotations
    ancol1 = [0.2, 0.2, 0.2]
    ancol2 = sscolors[0]
    ancol3 = sscolors[6]

    # Last positions where the cumulative share is 0%, <=20% and <=50%
    i01 = sc.findlast(sorted_sum==0)
    i20 = sc.findlast(sorted_sum<=20)
    i50 = sc.findlast(sorted_sum<=50)
    cum_ax.plot([xmin, index[i01]], [0, 0], '--', lw=2, c=ancol1)
    cum_ax.plot([xmin, index[i20], index[i20]], [20, 20, ymin], '--', lw=2, c=ancol2)
    cum_ax.plot([xmin, index[i50], index[i50]], [50, 50, ymin], '--', lw=2, c=ancol3)

    # Compute mean number of transmissions for 80% and 50% thresholds
    q80 = sc.findfirst(np.cumsum(total_counts)>20) # Count corresponding to 80% of cumulative infections (100-80)
    q50 = sc.findfirst(np.cumsum(total_counts)>50) # Count corresponding to 50% of cumulative infections
    # Mean count among cases at or above each threshold count
    n80, n50 = [sum(bins[q:]*norm_counts[q:]/norm_counts[q:].sum()) for q in [q80, q50]]

    # Plot annotations
    kw = dict(bbox=dict(facecolor='w', alpha=0.9, lw=0), fontsize=20-wf)
    cum_ax.text(2, 3, f'{index[i01]:0.0f}% of infections\ndo not transmit', c=ancol1, **kw)
    cum_ax.text(8, 23, f'{rev_ind[i20]:0.0f}% of infections cause\n80% of transmissions\n(mean: {n80:0.1f} per infection)', c=ancol2, **kw)
    cum_ax.text(14, 53, f'{rev_ind[i50]:0.0f}% of infections cause\n50% of transmissions\n(mean: {n50:0.1f} per infection)', c=ancol3, **kw)


    #%% Fig. 2D -- histogram by date of symptom onset

    # Calculate: weighted transmission counts keyed by (transmission date minus
    # the symptom-onset date in dd['s']), plus a separate asymptomatic total
    # (onset date is NaN); dd['s'] is presumably the source's record -- confirm
    asymp_count = 0
    symp_counts = {}
    minind = -5
    maxind = 15
    for _, target_ind in tt.transmissions:
        dd = tt.detailed[target_ind]
        date = dd['date']
        delta = sim.rescale_vec[date] # Increment counts by this much
        if dd['s']:
            if tt.detailed[dd['source']]['date'] <= date: # Skip dynamical scaling reinfections
                sdate = dd['s']['date_symptomatic']
                if np.isnan(sdate):
                    asymp_count += delta
                else:
                    ind = int(date - sdate)
                    if ind not in symp_counts:
                        symp_counts[ind] = 0
                    symp_counts[ind] += delta

    # Convert to an array over days minind-1..maxind; days below minind go into
    # the first bin and days above maxind into the last (maxind) bin
    xax = np.arange(minind-1, maxind+1)
    sympcounts = np.zeros(len(xax))
    for i,val in symp_counts.items():
        if i<minind:
            ind = 0
        elif i>maxind:
            ind = -1
        else:
            ind = sc.findinds(xax==i)[0]
        sympcounts[ind] += val

    # Plot as percentages of all counted transmissions; presymp is the index of
    # day 0, so bins before it are presymptomatic and day 0 onward symptomatic
    total_count = asymp_count + sympcounts.sum()
    sympcounts = sympcounts/total_count*100
    presymp = sc.findinds(xax<=0)[-1]
    colors = ['#eed15b', '#ee943a', '#c3211a']

    asymp_frac = asymp_count/total_count*100
    pre_frac = sympcounts[:presymp].sum()
    symp_frac = sympcounts[presymp:].sum()
    # The asymptomatic bar sits at xax[0]-2 (= minind-3), matching the 'Asymp.' tick
    symp_ax.bar(xax[0]-2, asymp_frac, label='Asymptomatic', color=colors[0])
    symp_ax.bar(xax[:presymp], sympcounts[:presymp], label='Presymptomatic', color=colors[1])
    symp_ax.bar(xax[presymp:], sympcounts[presymp:], label='Symptomatic', color=colors[2])
    symp_ax.set_xlabel('Days since symptom onset')
    symp_ax.set_ylabel('Proportion of transmissions (%)')
    symp_ax.set_xticks([minind-3, 0, 5, 10, maxind])
    symp_ax.set_xticklabels(['Asymp.', '0', '5', '10', f'>{maxind}'])
    sc.boxoff(ax=symp_ax)

    spie_ax = pl.axes([sympx+0.05, 0.20, 0.2, 0.2])
    labels = [f'Asymp-\ntomatic\n{asymp_frac:0.0f}%', f' Presymp-\n tomatic\n {pre_frac:0.0f}%', f'Symp-\ntomatic\n{symp_frac:0.0f}%']
    spie_ax.pie([asymp_frac, pre_frac, symp_frac], labels=labels, colors=colors, **pieargs)

    return fig