示例#1
0
def title_grid_legend(ax, title, grid, commaticks, setylim, legend_args, show_legend=True):
    '''
    Plot styling -- set the plot title, add a legend, and optionally add gridlines.

    Args:
        ax (Axes): the axes to style
        title (str): title for the axes
        grid (bool): whether to show gridlines (passed to ``ax.grid()``)
        commaticks (bool): format the y-axis with comma separators, but only if the
            upper y-limit is at least 1000
        setylim (bool): pin the bottom of the y-axis at 0 (skipped for log-scaled
            y-axes, where a limit of 0 is invalid)
        legend_args (dict): kwargs passed to ``ax.legend()``; may also contain a
            'show_legend' entry, which overrides the ``show_legend`` argument. The
            caller's dict is never modified; None is treated as empty.
        show_legend (bool): whether to draw the legend
    '''

    # 'show_legend' can arrive inside legend_args (in some cases that is the only way
    # it can get passed), but it is not an ax.legend() kwarg -- strip it from a copy
    # so the caller's dict is left untouched even if ax.legend() raises
    legend_kwargs = dict(legend_args) if legend_args else {}
    show_legend = legend_kwargs.pop('show_legend', show_legend)

    # Show the legend
    if show_legend:
        ax.legend(**legend_kwargs)

    # Set the title and gridlines
    ax.set_title(title)
    ax.grid(grid)

    # Set the y axis style; a bottom limit of 0 would be ignored (with a warning) on a log axis
    if setylim and ax.yaxis.get_scale() != 'log':
        ax.set_ylim(bottom=0)
    if commaticks:
        ylims = ax.get_ylim()
        if ylims[1] >= 1000:
            sc.commaticks(ax=ax)

    return
示例#2
0
def format_ax(ax, sim):
    '''
    Format the given axes: date labels on the x-axis and comma-separated y-axis ticks.

    Args:
        ax (Axes): the axes to format
        sim: object whose ``sim['start_day']`` is the date corresponding to x = 0
    '''
    @ticker.FuncFormatter
    def date_formatter(x, pos):
        # x is a day offset from the simulation start day
        return (sim['start_day'] + dt.timedelta(days=x)).strftime('%b-%d')

    ax.xaxis.set_major_formatter(date_formatter)
    sc.commaticks(ax=ax)  # Format the axes we were given, not whichever axes happens to be current
    return
示例#3
0
def title_grid_legend(ax,
                      title,
                      grid,
                      commaticks,
                      setylim,
                      legend_args,
                      show_args,
                      show_legend=True):
    '''
    Plot styling -- set the plot title, add a deduplicated legend, style the
    y-axis, and optionally hide x-axis tick labels on non-bottom subplots.

    Args:
        ax (Axes): the axes to style
        title (str): title for the axes
        grid (bool): accepted for interface compatibility but not applied here.
            NOTE(review): this function does not set gridlines -- confirm they are
            controlled elsewhere (e.g. by the plotting style)
        commaticks (bool): format the y-axis with comma separators, but only if the
            upper y-limit is at least 1000
        setylim (bool): pin the bottom of the y-axis at 0 (skipped for log-scaled y-axes)
        legend_args (dict): kwargs passed to ``ax.legend()``; may also contain a
            'show_legend' entry, which overrides the ``show_legend`` argument. The
            caller's dict is never modified; None is treated as empty.
        show_args (dict): must contain the flags 'legend' (the legend is shown only if
            this and ``show_legend`` are both truthy) and 'outer' (hide the x tick
            labels and x-label on axes that are not in the bottom row of their grid)
        show_legend (bool): whether to draw the legend
    '''

    # 'show_legend' can arrive inside legend_args (in some cases that is the only way
    # it can get passed), but it is not an ax.legend() kwarg -- strip it from a copy
    legend_kwargs = dict(legend_args) if legend_args else {}
    show_legend = legend_kwargs.pop('show_legend', show_legend)

    # It's pretty ugly, but there are multiple ways of controlling whether the legend shows
    if show_legend and show_args['legend']:

        # Remove duplicate entries: keep the first artist for each label, in first-seen order
        handles, labels = ax.get_legend_handles_labels()
        first_handle = {}
        for handle, label in zip(handles, labels):
            first_handle.setdefault(label, handle)

        # Actually make legend
        ax.legend(handles=list(first_handle.values()), labels=list(first_handle), **legend_kwargs)

    # Set the title
    ax.set_title(title)

    # Set the y axis style; a bottom limit of 0 is invalid on a log axis
    if setylim and ax.yaxis.get_scale() != 'log':
        ax.set_ylim(bottom=0)
    if commaticks:
        ylims = ax.get_ylim()
        if ylims[1] >= 1000:
            sc.commaticks(ax=ax)

    # Optionally remove x-axis labels except on bottom plots -- don't use ax.label_outer() since we need to keep the y-labels
    if show_args['outer']:
        # get_subplotspec() returns None for axes that are not part of a subplot grid
        subplotspec = ax.get_subplotspec()
        if subplotspec is not None and not subplotspec.is_last_row():
            for ticklabel in ax.get_xticklabels(which="both"):
                ticklabel.set_visible(False)
            ax.set_xlabel('')

    return
 def plot(self):
     '''Draw the S, E, I and R time series (against self.t) in a new figure.'''
     pl.figure()
     for compartment in ('S', 'E', 'I', 'R'):
         pl.plot(self.t, getattr(self, compartment), label=compartment)
     pl.legend()
     pl.xlabel('Day')
     pl.ylabel('People')
     sc.setylim()  # Start the y-axis at 0
     sc.commaticks()  # Comma-separated y-axis tick labels
     return
示例#5
0
def format_ax(ax, sim, key=None):
    '''
    Format the axes nicely: date labels on the x-axis, comma-separated y ticks
    (except when key is 'r_eff'), x-limits covering the simulation, and sc.boxoff().

    Args:
        ax (Axes): the axes to format -- all changes are applied to this axes,
            not to whichever axes is current
        sim: object providing ``sim['start_day']`` (date at x = 0) and ``sim['n_days']``
        key (str): result key being plotted; 'r_eff' skips comma ticks
    '''
    @ticker.FuncFormatter
    def date_formatter(x, pos):
        # x is a day offset from the simulation start day
        return (sim['start_day'] + dt.timedelta(days=x)).strftime('%b-%d')

    ax.xaxis.set_major_formatter(date_formatter)
    if key != 'r_eff':
        sc.commaticks(ax=ax)
    ax.set_xlim([0, sim['n_days']])
    sc.boxoff(ax=ax)
    return
def format_ax(ax, sim, key=None):
    '''
    Format the axes: month/year labels on the x-axis, comma-separated y ticks
    (except when key is 'r_eff'), x-limits covering the simulation, and shaded
    bands for the three lockdown periods.

    Args:
        ax (Axes): the axes to format -- all changes are applied to this axes,
            not to whichever axes is current
        sim: object providing ``sim['start_day']`` (date at x = 0) and ``sim['n_days']``
        key (str): result key being plotted; 'r_eff' skips comma ticks

    Relies on module-level ``lockdown1``, ``lockdown2`` and ``lockdown3``, whose
    elements [0] and [1] are used as the start/end x positions (day offsets) of each band.
    '''
    @ticker.FuncFormatter
    def date_formatter(x, pos):
        # x is a day offset from the start day, truncated to a whole number of days
        return (sim['start_day'] + dt.timedelta(days=int(x))).strftime('%b\n%y')

    ax.xaxis.set_major_formatter(date_formatter)
    if key != 'r_eff':
        sc.commaticks(ax=ax)
    ax.set_xlim([0, sim['n_days']])
    for period, color in [(lockdown1, 'steelblue'), (lockdown2, 'steelblue'), (lockdown3, 'lightblue')]:
        ax.axvspan(period[0], period[1], color=color, alpha=0.2, lw=0)

    return
def format_axs(axs, key=None):
    '''
    Apply the same formatting to every axes in axs: an x tick every 21 days
    labelled as a date, a frameless legend, then sc.boxoff, sc.setylim and
    sc.commaticks. The key argument is accepted but not used.

    Dates are counted from the module-level ``refsim['start_day']``.
    '''
    tick_spacing = 21  # days between x ticks

    @ticker.FuncFormatter
    def date_formatter(x, pos):
        start = refsim['start_day']
        return (start + dt.timedelta(days=x)).strftime('%b-%d')

    for idx, ax in enumerate(axs):
        lo, hi = ax.get_xlim()
        ax.set_xticks(np.arange(lo, hi, tick_spacing))
        ax.xaxis.set_major_formatter(date_formatter)
        anchor = (1.05, 1.05) if idx == 1 else None  # Move the second panel's legend up a bit
        ax.legend(frameon=False, bbox_to_anchor=anchor)
        sc.boxoff(ax=ax)
        sc.setylim(ax=ax)
        sc.commaticks(ax=ax)
    return
示例#8
0
    def plot(self,
             to_plot=None,
             do_save=None,
             fig_path=None,
             fig_args=None,
             plot_args=None,
             axis_args=None,
             fill_args=None,
             legend_args=None,
             as_dates=True,
             dateformat=None,
             interval=None,
             n_cols=1,
             font_size=18,
             font_family=None,
             grid=True,
             commaticks=True,
             do_show=True,
             sep_figs=False,
             verbose=None):
        '''
        Plot the results -- can supply arguments for both the figure and the plots.

        Args:
            to_plot     (dict): Dict of results to plot; see default_scen_plots for structure
            do_save     (bool): Whether or not to save the figure
            fig_path    (str):  Path to save the figure
            fig_args    (dict): Dictionary of kwargs to be passed to pl.figure()
            plot_args   (dict): Dictionary of kwargs to be passed to pl.plot()
            axis_args   (dict): Dictionary of kwargs to be passed to pl.subplots_adjust()
            fill_args   (dict): Dictionary of kwargs to be passed to pl.fill_between()
            legend_args (dict): Dictionary of kwargs to be passed to pl.legend()
            as_dates    (bool): Whether to plot the x-axis as dates or time points
            dateformat  (str):  Date string format, e.g. '%B %d' (default '%b-%d')
            interval    (int):  Interval between tick marks
            n_cols      (int):  Number of columns of subpanels to use for subplot
            font_size   (int):  Size of the font
            font_family (str):  Font face
            grid        (bool): Whether or not to plot gridlines
            commaticks  (bool): Plot y-axis with commas rather than scientific notation
            do_show     (bool): Whether or not to show the figure
            sep_figs    (bool): Whether to show separate figures for different results instead of subplots
            verbose     (bool): Display a bit of extra information

        Returns:
            fig: Figure handle, or a list of figure handles if sep_figs is True.
            When saving with sep_figs=True, only the last (current) figure is saved.
        '''

        if verbose is None:
            verbose = self['verbose']
        sc.printv('Plotting...', 1, verbose)

        if to_plot is None:
            to_plot = cvd.default_scen_plots
        to_plot = sc.dcp(to_plot)  # Deep copy so neither the caller's object nor the defaults can be modified

        # Handle input arguments -- merge user input with defaults
        fig_args = sc.mergedicts({'figsize': (16, 14)}, fig_args)
        plot_args = sc.mergedicts({'lw': 3, 'alpha': 0.7}, plot_args)
        axis_args = sc.mergedicts(
            {
                'left': 0.10,
                'bottom': 0.05,
                'right': 0.95,
                'top': 0.90,
                'wspace': 0.25,
                'hspace': 0.25
            }, axis_args)
        fill_args = sc.mergedicts({'alpha': 0.2}, fill_args)
        legend_args = sc.mergedicts({'loc': 'best'}, legend_args)
        if dateformat is None:
            dateformat = '%b-%d'

        if sep_figs:
            figs = []
        else:
            fig = pl.figure(**fig_args)
            pl.subplots_adjust(**axis_args)
        pl.rcParams['font.size'] = font_size
        if font_family:
            pl.rcParams['font.family'] = font_family

        # Convert x-axis day offsets (relative to the base sim's start day) into date labels
        @ticker.FuncFormatter
        def date_formatter(x, pos):
            return (self.base_sim['start_day'] + dt.timedelta(days=x)).strftime(dateformat)

        base_data = self.base_sim.data
        n_rows = int(np.ceil(len(to_plot) / n_cols))  # pl.subplot() requires an integer number of rows
        for rk, reskey in enumerate(to_plot):
            title = self.base_sim.results[reskey].name  # Get the name of this result from the base simulation
            if sep_figs:
                figs.append(pl.figure(**fig_args))
                pl.subplots_adjust(**axis_args)  # Apply the margins to each separate figure
                ax = pl.subplot(111)
            else:
                ax = pl.subplot(n_rows, n_cols, rk + 1)

            # Draw the uncertainty band and best estimate for every scenario
            for scendata in self.results[reskey].values():
                pl.fill_between(self.tvec, scendata.low, scendata.high, **fill_args)
                pl.plot(self.tvec, scendata.best, label=scendata.name, **plot_args)

            # Per-axes styling only needs doing once, after all scenarios are drawn
            pl.title(title)
            if rk == 0:
                pl.legend(**legend_args)
            pl.grid(grid)
            if commaticks:
                sc.commaticks()

            # Overlay the data once per axes (not once per scenario)
            if base_data is not None and reskey in base_data:
                data_t = np.array((base_data.index - self.base_sim['start_day']) / np.timedelta64(1, 'D'))
                pl.plot(data_t, base_data[reskey], 'sk', **plot_args)

            # Optionally reset tick marks (useful for e.g. plotting weeks/months)
            if interval:
                xmin, xmax = ax.get_xlim()
                ax.set_xticks(pl.arange(xmin, xmax + 1, interval))

            # Set xticks as dates
            if as_dates:
                ax.xaxis.set_major_formatter(date_formatter)
                if not interval:
                    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        # Ensure the figure actually renders or saves
        if do_save:
            if fig_path is None:  # No figpath provided
                fig_path = 'covasim_scenarios.png'  # Just give it a default name
            fig_path = sc.makefilepath(fig_path)  # Ensure it's valid, including creating the folder
            pl.savefig(fig_path)

        all_figs = figs if sep_figs else [fig]
        if do_show:
            pl.show()
        else:
            for this_fig in all_figs:
                pl.close(this_fig)

        return figs if sep_figs else fig
示例#9
0
    def plot(self, to_plot=None, do_save=None, fig_path=None, fig_args=None, plot_args=None,
             axis_args=None, fill_args=None, as_dates=True, interval=None, dateformat=None,
             font_size=18, font_family=None, grid=True, commaticks=True, do_show=True, sep_figs=False,
             verbose=None):
        '''
        Plot the results -- can supply arguments for both the figure and the plots.

        Args:
            to_plot     (dict): Dict of results to plot; see default_scen_plots for structure
            do_save     (bool): Whether or not to save the figure
            fig_path    (str):  Path to save the figure
            fig_args    (dict): Dictionary of kwargs to be passed to pl.figure()
            plot_args   (dict): Dictionary of kwargs to be passed to pl.plot()
            axis_args   (dict): Dictionary of kwargs to be passed to pl.subplots_adjust()
            fill_args   (dict): Dictionary of kwargs to be passed to pl.fill_between()
            as_dates    (bool): Whether to plot the x-axis as dates or time points
            interval    (int):  Interval between tick marks
            dateformat  (str):  Date string format, e.g. '%B %d'
            font_size   (int):  Size of the font
            font_family (str):  Font face
            grid        (bool): Whether or not to plot gridlines
            commaticks  (bool): Plot y-axis with commas rather than scientific notation
            do_show     (bool): Whether or not to show the figure
            sep_figs    (bool): Whether to show separate figures for different results instead of subplots
            verbose     (bool): Display a bit of extra information

        Returns:
            fig: Figure handle, or a list of figure handles if sep_figs is True.
            When saving with sep_figs=True, only the last (current) figure is saved.
        '''

        if verbose is None:
            verbose = self['verbose']
        sc.printv('Plotting...', 1, verbose)

        if to_plot is None:
            to_plot = default_scen_plots
        to_plot = sc.odict(sc.dcp(to_plot)) # In case it's supplied as a dict

        # Handle input arguments -- merge user input with defaults
        fig_args  = sc.mergedicts({'figsize': (16, 12)}, fig_args)
        plot_args = sc.mergedicts({'lw': 3, 'alpha': 0.7}, plot_args)
        axis_args = sc.mergedicts({'left': 0.10, 'bottom': 0.05, 'right': 0.95, 'top': 0.90, 'wspace': 0.5, 'hspace': 0.25}, axis_args)
        fill_args = sc.mergedicts({'alpha': 0.2}, fill_args)

        if sep_figs:
            figs = []
        else:
            fig = pl.figure(**fig_args)
            pl.subplots_adjust(**axis_args)
        pl.rcParams['font.size'] = font_size
        if font_family:
            pl.rcParams['font.family'] = font_family

        # %% Plotting
        for rk, reskey, title in to_plot.enumitems():
            if sep_figs:
                figs.append(pl.figure(**fig_args))
                pl.subplots_adjust(**axis_args)  # Apply the margins to each separate figure
                ax = pl.subplot(111)
            else:
                ax = pl.subplot(len(to_plot), 1, rk + 1)

            # Draw the uncertainty band and best estimate for every scenario
            for scendata in self.allres[reskey].values():
                pl.fill_between(self.tvec, scendata.low, scendata.high, **fill_args)
                pl.plot(self.tvec, scendata.best, label=scendata.name, **plot_args)

            # Per-axes styling only needs doing once, after all scenarios are drawn
            pl.title(title)
            if rk == 0:
                pl.legend(loc='best')
            pl.grid(grid)
            if commaticks:
                sc.commaticks()

            # Optionally reset tick marks (useful for e.g. plotting weeks/months)
            if interval:
                xmin, xmax = ax.get_xlim()
                ax.set_xticks(pl.arange(xmin, xmax+1, interval))

            # Set xticks as dates
            if as_dates:
                xticks = ax.get_xticks()
                xticklabels = self.base_sim.inds2dates(xticks, dateformat=dateformat)
                ax.set_xticklabels(xticklabels)

        # Ensure the figure actually renders or saves
        if do_save:
            if fig_path is None: # No figpath provided
                fig_path = 'covasim_scenarios.png' # Just give it a default name
            fig_path = sc.makefilepath(fig_path) # Ensure it's valid, including creating the folder
            pl.savefig(fig_path)

        all_figs = figs if sep_figs else [fig]
        if do_show:
            pl.show()
        else:
            for this_fig in all_figs:
                pl.close(this_fig)

        return figs if sep_figs else fig
示例#10
0
文件: sim.py 项目: willf/covasim
    def plot(self,
             to_plot=None,
             do_save=None,
             fig_path=None,
             fig_args=None,
             plot_args=None,
             scatter_args=None,
             axis_args=None,
             legend_args=None,
             as_dates=True,
             dateformat=None,
             interval=None,
             n_cols=1,
             font_size=18,
             font_family=None,
             use_grid=True,
             use_commaticks=True,
             do_show=True,
             verbose=None):
        '''
        Plot the results -- can supply arguments for both the figure and the plots.

        Args:
            to_plot (dict): Nested dict of results to plot; see default_sim_plots for structure
            do_save (bool or str): Whether or not to save the figure. If a string, save to that filename.
            fig_path (str): Path to save the figure
            fig_args (dict): Dictionary of kwargs to be passed to pl.figure()
            plot_args (dict): Dictionary of kwargs to be passed to pl.plot()
            scatter_args (dict): Dictionary of kwargs to be passed to pl.scatter()
            axis_args (dict): Dictionary of kwargs to be passed to pl.subplots_adjust()
            legend_args (dict): Dictionary of kwargs to be passed to pl.legend()
            as_dates (bool): Whether to plot the x-axis as dates or time points
            dateformat (str): Date string format, e.g. '%B %d' (default '%b-%d')
            interval (int): Interval between tick marks
            n_cols (int): Number of columns of subpanels to use for subplot
            font_size (int): Size of the font
            font_family (str): Font face
            use_grid (bool): Whether or not to plot gridlines
            use_commaticks (bool): Plot y-axis with commas rather than scientific notation
            do_show (bool): Whether or not to show the figure
            verbose (bool): Display a bit of extra information

        Returns:
            fig: Figure handle
        '''

        if verbose is None:
            verbose = self['verbose']
        sc.printv('Plotting...', 1, verbose)

        if to_plot is None:
            to_plot = cvd.default_sim_plots
        to_plot = sc.odict(to_plot)  # In case it's supplied as a dict

        # Handle input arguments -- merge user input with defaults
        fig_args = sc.mergedicts({'figsize': (16, 14)}, fig_args)
        plot_args = sc.mergedicts({'lw': 3, 'alpha': 0.7}, plot_args)
        scatter_args = sc.mergedicts({'s': 70, 'marker': 's'}, scatter_args)
        axis_args = sc.mergedicts(
            {
                'left': 0.1,
                'bottom': 0.05,
                'right': 0.9,
                'top': 0.97,
                'wspace': 0.2,
                'hspace': 0.25
            }, axis_args)
        legend_args = sc.mergedicts({'loc': 'best'}, legend_args)
        if dateformat is None:
            dateformat = '%b-%d'

        fig = pl.figure(**fig_args)
        pl.subplots_adjust(**axis_args)
        pl.rcParams['font.size'] = font_size
        if font_family:
            pl.rcParams['font.family'] = font_family

        res = self.results  # Shorten since heavily used

        # Convert x-axis day offsets (relative to the start day) into date labels; shared by all panels
        @ticker.FuncFormatter
        def date_formatter(x, pos):
            return (self['start_day'] + dt.timedelta(days=x)).strftime(dateformat)

        # Plot everything
        n_rows = int(np.ceil(len(to_plot) / n_cols))  # pl.subplot() requires an integer number of rows
        for p, title, keylabels in to_plot.enumitems():
            ax = pl.subplot(n_rows, n_cols, p + 1)
            for key in keylabels:
                label = res[key].name
                this_color = res[key].color
                y = res[key].values
                pl.plot(res['t'], y, label=label, **plot_args, c=this_color)
                if self.data is not None and key in self.data:
                    # Convert from data date to model output index based on model start date
                    data_t = (self.data.index - self['start_day']) / np.timedelta64(1, 'D')
                    pl.scatter(data_t, self.data[key], c=[this_color], **scatter_args)
            if self.data is not None and len(self.data):
                # A point at (nan, nan) draws nothing but adds one black 'Data' legend entry
                pl.scatter(pl.nan, pl.nan, c=[(0, 0, 0)], label='Data', **scatter_args)

            pl.legend(**legend_args)
            pl.grid(use_grid)
            sc.setylim()
            if use_commaticks:
                sc.commaticks()
            pl.title(title)

            # Optionally reset tick marks (useful for e.g. plotting weeks/months)
            if interval:
                xmin, xmax = ax.get_xlim()
                ax.set_xticks(pl.arange(xmin, xmax + 1, interval))

            # Set xticks as dates
            if as_dates:
                ax.xaxis.set_major_formatter(date_formatter)
                if not interval:
                    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))

            # Plot interventions
            for intervention in self['interventions']:
                intervention.plot(self, ax)

        # Ensure the figure actually renders or saves
        if do_save:
            if fig_path is None:  # No figpath provided - see whether do_save is a figpath
                if isinstance(do_save, str):
                    fig_path = do_save  # It's a string, assume it's a filename
                else:
                    fig_path = 'covasim.png'  # Just give it a default name
            fig_path = sc.makefilepath(fig_path)  # Ensure it's valid, including creating the folder
            pl.savefig(fig_path)

        if do_show:
            pl.show()
        else:
            pl.close(fig)

        return fig
                    elinewidth=3,
                    capsize=0)

# Put the box plot's x ticks 0.15 to the left of the positions in x
# NOTE(review): x must support array arithmetic here (presumably a NumPy array defined above) -- confirm
box_ax.set_xticks(x - 0.15)
#box_ax.set_xticklabels(labels)  # disabled


@ticker.FuncFormatter
def date_formatter(x, pos):
    '''Turn tick value x, a number of weeks after 2021-01-12, into a 'DD-Mon' label.'''
    n_days = x * 7
    tick_date = cv.date('2021-01-12') + dt.timedelta(days=n_days)
    return tick_date.strftime('%d-%b')


# Finish styling panel A: weekly date labels on box_ax, y-label, spines and legend
box_ax.xaxis.set_major_formatter(date_formatter)
pl.ylabel('Estimated daily infections (000s)')  # acts on the current axes -- presumably box_ax; confirm
sc.boxoff(ax=box_ax)
sc.commaticks()  # NOTE(review): no ax given, so this acts on the current axes rather than explicitly box_ax
box_ax.legend(frameon=False)

# B. Cumulative total infections
width = 0.8  # the width of the bars
x = [0, 1, 2]  # bar positions, one per scenario (rebinds the x used for panel A)
# Per scenario (sn is a scenario key into msims here): cumulative infections on the
# final day minus cumulative infections at sim.day('2021-01-04'), i.e. infections
# accumulated after that date (sim.day presumably maps a date to a day index -- confirm)
data = np.array([
    msims[sn].results['cum_infections'].values[-1] -
    msims[sn].results['cum_infections'].values[sim.day('2021-01-04')]
    for sn in scenarios
])
# New axes at [left, bottom, width, height] in figure coordinates
bar_ax = pl.axes([xgapl + xgapm + dx1, ygapb, dx2, dy])
for sn, scen in enumerate(scenarios):  # here sn is the scenario's index
    bar_ax.bar(x[sn], data[sn] / 1e3, width, color=colors[sn], alpha=1.0)  # bar heights in thousands

# NOTE(review): the leading '' assumes the default tick locator puts an extra tick before
# the first bar; setting explicit ticks (bar_ax.set_xticks(x)) would be more robust
bar_ax.set_xticklabels(['', 'FNL', 'Primary-only\nPNL', 'Staggered\nPNL'])
示例#12
0
文件: sim.py 项目: haohu1/covasim
    def plot(self,
             to_plot=None,
             do_save=None,
             fig_path=None,
             fig_args=None,
             plot_args=None,
             scatter_args=None,
             axis_args=None,
             as_dates=True,
             interval=None,
             dateformat=None,
             font_size=18,
             font_family=None,
             use_grid=True,
             use_commaticks=True,
             do_show=True,
             verbose=None):
        '''
        Plot the results -- can supply arguments for both the figure and the plots.
        Each entry of to_plot is drawn in its own subplot row.

        Args:
            to_plot (dict): Nested dict of results to plot; see default_sim_plots for structure
                (the inner mappings may be plain dicts or odicts)
            do_save (bool or str): Whether or not to save the figure. If a string, save to that filename.
            fig_path (str): Path to save the figure
            fig_args (dict): Dictionary of kwargs to be passed to pl.figure()
            plot_args (dict): Dictionary of kwargs to be passed to pl.plot()
            scatter_args (dict): Dictionary of kwargs to be passed to pl.scatter()
            axis_args (dict): Dictionary of kwargs to be passed to pl.subplots_adjust()
            as_dates (bool): Whether to plot the x-axis as dates or time points
            interval (int): Interval between tick marks
            dateformat (str): Date string format, e.g. '%B %d'
            font_size (int): Size of the font
            font_family (str): Font face
            use_grid (bool): Whether or not to plot gridlines
            use_commaticks (bool): Plot y-axis with commas rather than scientific notation
            do_show (bool): Whether or not to show the figure
            verbose (bool): Display a bit of extra information

        Returns:
            fig: Figure handle
        '''

        if verbose is None:
            verbose = self['verbose']
        sc.printv('Plotting...', 1, verbose)

        if to_plot is None:
            to_plot = default_sim_plots
        to_plot = sc.odict(to_plot)  # In case it's supplied as a dict

        # Handle input arguments -- merge user input with defaults
        fig_args = sc.mergedicts({'figsize': (16, 12)}, fig_args)
        plot_args = sc.mergedicts({'lw': 3, 'alpha': 0.7}, plot_args)
        scatter_args = sc.mergedicts({'s': 150, 'marker': 's'}, scatter_args)
        axis_args = sc.mergedicts(
            {
                'left': 0.1,
                'bottom': 0.05,
                'right': 0.9,
                'top': 0.97,
                'wspace': 0.2,
                'hspace': 0.25
            }, axis_args)

        fig = pl.figure(**fig_args)
        pl.subplots_adjust(**axis_args)
        pl.rcParams['font.size'] = font_size
        if font_family:
            pl.rcParams['font.family'] = font_family

        res = self.results  # Shorten since heavily used

        # Plot everything

        # One colour per line position, enough for the panel with the most lines
        colors = sc.gridcolors(max(len(tp) for tp in to_plot.values()))

        # Define the data mapping. Must be here since uses functions
        if self.data is not None and len(self.data):
            data_mapping = {
                'cum_exposed': pl.cumsum(self.data['new_infections']),
                'cum_diagnosed': pl.cumsum(self.data['new_positives']),
                'cum_tested': pl.cumsum(self.data['new_tests']),
                'infections': self.data['new_infections'],
                'tests': self.data['new_tests'],
                'diagnoses': self.data['new_positives'],
            }
        else:
            data_mapping = {}

        n_rows = len(to_plot)  # One subplot row per panel, so any number of panels fits
        for p, title, keylabels in to_plot.enumitems():
            ax = pl.subplot(n_rows, 1, p + 1)
            for i, key, label in sc.odict(keylabels).enumitems():  # enumitems() needs an odict
                this_color = colors[i]
                y = res[key].values
                pl.plot(res['t'], y, label=label, **plot_args, c=this_color)
                if key in data_mapping:
                    pl.scatter(self.data['day'], data_mapping[key], c=[this_color], **scatter_args)
            if self.data is not None and len(self.data):
                # A point at (nan, nan) draws nothing but adds one black 'Data' legend entry
                pl.scatter(pl.nan, pl.nan, c=[(0, 0, 0)], label='Data', **scatter_args)

            pl.grid(use_grid)
            cvu.fixaxis(self)
            if use_commaticks:
                sc.commaticks()
            pl.title(title)

            # Optionally reset tick marks (useful for e.g. plotting weeks/months)
            if interval:
                xmin, xmax = ax.get_xlim()
                ax.set_xticks(pl.arange(xmin, xmax + 1, interval))

            # Set xticks as dates
            if as_dates:
                xticks = ax.get_xticks()
                xticklabels = self.inds2dates(xticks, dateformat=dateformat)
                ax.set_xticklabels(xticklabels)

            # Plot interventions
            for intervention in self['interventions']:
                intervention.plot(self, ax)

        # Ensure the figure actually renders or saves
        if do_save:
            if fig_path is None:  # No figpath provided - see whether do_save is a figpath
                if isinstance(do_save, str):
                    fig_path = do_save  # It's a string, assume it's a filename
                else:
                    fig_path = 'covasim.png'  # Just give it a default name
            fig_path = sc.makefilepath(fig_path)  # Ensure it's valid, including creating the folder
            pl.savefig(fig_path)

        if do_show:
            pl.show()
        else:
            pl.close(fig)

        return fig
示例#13
0
    def plot(self,
             do_save=None,
             fig_args=None,
             plot_args=None,
             scatter_args=None,
             fill_args=None,
             axis_args=None,
             font_size=None,
             font_family=None,
             use_grid=True,
             do_show=True,
             verbose=None):
        '''
        Plotting, copied from run_cdc_scenarios.

        Args:
            do_save (bool/str):  whether or not to save the figure; if a string, save to that
                                 filename, otherwise to 'covasim_beds.png'
            fig_args (dict):     options for styling the figure (e.g. size)
            plot_args (dict):    likewise, for the plot (e.g., line thickness)
            scatter_args (dict): likewise, for scatter points -- currently unused, since no data are plotted
            fill_args (dict):    likewise, for uncertainty bounds (e.g. alpha)
            axis_args (dict):    likewise, for axes (e.g. margins)
            font_size (int):     overall figure font size
            font_family (str):   what font to use (must exist on your system!)
            use_grid (bool):     whether or not to plot gridlines on the plot
            do_show (bool):      whether or not to show the figure
            verbose (bool):      whether or not to print extra output -- currently unused

        Returns:
            fig: a matplotlib figure object
        '''

        if fig_args is None: fig_args = {'figsize': (16, 12)}
        if plot_args is None: plot_args = {'lw': 3, 'alpha': 0.7}
        if axis_args is None:
            axis_args = {
                'left': 0.1,
                'bottom': 0.05,
                'right': 0.9,
                'top': 0.97,
                'wspace': 0.2,
                'hspace': 0.25
            }
        if fill_args is None: fill_args = {'alpha': 0.2}
        if font_size is None: font_size = 18
        if font_family is None: font_family = 'Proxima Nova'

        fig = pl.figure(**fig_args)
        pl.subplots_adjust(**axis_args)
        pl.rcParams['font.size'] = font_size
        pl.rcParams['font.family'] = font_family

        xmin = self.startday
        xmax = xmin + self.npts  # TODO: fix!!!
        tvec = xmin + pl.arange(self.npts)  # TODO: fix! With dates!

        n_res = len(self.reskeys)
        for rk, reskey in enumerate(self.reskeys):
            ax = pl.subplot(n_res, 1, rk + 1)

            # Draw the uncertainty band and best estimate for every scenario
            for scenkey, scendata in self.beds[reskey].items():
                pl.fill_between(tvec, scendata.low, scendata.high, **fill_args)
                pl.plot(tvec, scendata.best, label=self.scenlabels[scenkey], **plot_args)

            # Per-axes styling only needs doing once, after all scenarios are drawn
            if rk == 0:
                pl.legend()
            pl.title(self.reslabels[rk])
            pl.grid(use_grid)

            # Weekly x ticks, labelled as dates counted from 2020-01-01 (whole days)
            ax.set_xticks(pl.arange(xmin, xmax + 1, 7))
            labels = [(dt.datetime(2020, 1, 1) + dt.timedelta(days=int(t))).strftime('%B %d')
                      for t in ax.get_xticks()]
            ax.set_xticklabels(labels)
            sc.commaticks(axis='y')

        # Save the figure if requested
        if do_save:
            fig_path = do_save if isinstance(do_save, str) else 'covasim_beds.png'
            fig_path = sc.makefilepath(fig_path)  # Ensure it's valid, including creating the folder
            pl.savefig(fig_path)

        if do_show:
            pl.show()

        return fig
示例#14
0
#%% Plotting

pl.figure(figsize=(18, 6), dpi=200)
# pl.rcParams['font.size'] = 18

# Panel 1 of 3: results over time, one line per key, coloured by log10 of the population size
pl.subplot(1, 3, 1)
colors = sc.vectocolor(pl.log10(popsizes), cmap='parula')
for k, key in enumerate(keys):
    # Label: population size in thousands (parsed from the key after its first character) and the final value
    label = f'{int(float(key[1:]))/1000}k: {results[k][-1]:0.0f}'
    pl.plot(results[k], label=label, lw=3, color=colors[k])
    print(label)
# pl.legend()
pl.title('Total number of infections')
pl.xlabel('Day')
pl.ylabel('Number of infections')
sc.commaticks(axis='y')

# Panel 2 of 3: the same series as a percentage of each population size (attack rate)
pl.subplot(1, 3, 2)
for k, key in enumerate(keys):
    # Label: population size in thousands and the final attack rate (%)
    label = f'{int(float(key[1:]))/1000}k: {results[k][-1]/popsizes[k]*100:0.1f}'
    pl.plot(results[k] / popsizes[k] * 100, label=label, lw=3, color=colors[k])
    print(label)
# pl.legend()
pl.title('Attack rate')
pl.xlabel('Day')
pl.ylabel('Attack rate (%)')

# Panel 3 of 3: final value of each series against its population size
pl.subplot(1, 3, 3)
fres = [res[-1] for res in results]  # last entry of each series
pl.scatter(popsizes, fres, s=150, c=colors)
pl.title('Correlation')
示例#15
0
    for betakey in betadict.keys():
        for scenkey in scenkeys:
            tvec = sims[0].results['t']
            average = np.zeros(sims[0].npts)
            for seed in np.arange(n_seeds):
                label = f'{betakey}_{scenkey}_{seed}'
                res = sims[label].results
                cum_infections = res['cum_infections'].values
                a[betakey].plot(tvec, cum_infections, c=colors[scenkey], lw=3, alpha=0.2)
                average += cum_infections/n_seeds
            a[betakey].plot(tvec, average, '--', c=colors[scenkey], lw=3, alpha=1.0, zorder=10, label=scenmap[scenkey])

    for betakey in betadict.keys():
        ax = a[betakey]
        sc.commaticks(ax=ax)
        sc.setylim(ax=ax)
        ax.set_xlim([0,150])
        ax.set_ylabel('Cumulative infections')
        betanorm = betadict[betakey]*100*3*0.78 # Normalized to match Fig. 1 beta
        ax.set_title(rf'{betamap[betakey]} transmission, $\beta$ = {betanorm:0.1f}%')
        sc.boxoff(ax=ax)
        if betakey == 'best':
            ax.legend(frameon=False)
        if betakey == 'high':
            ax.set_yticks(np.arange(8)*2e3)
            ax.set_xlabel('Days since seed infections')

    return fig