예제 #1
0
def test_printing():
    '''
    Exercise the sciris pretty-printing helpers.

    Wraps the output of sc.vectocolor(10) in a prettyobj, shows the object with
    sc.pr() and its data with sc.pp(), then returns the text that sc.pp()
    produces when printing is suppressed.
    '''
    obj = sc.prettyobj()
    obj.data = sc.vectocolor(10)

    print('sc.pr():')
    sc.pr(obj)

    payload = obj.data
    print('sc.pp():')
    sc.pp(payload)
    return sc.pp(payload, doprint=False)
예제 #2
0
def test_printing():
    '''
    Run the sciris printing helpers (pr, pp, printdata) on a prettyobj.

    Returns:
        The string produced by sc.pp(..., doprint=False) for the object's data.
    '''
    sc.heading('test_printing()')
    obj = sc.prettyobj()
    obj.data = sc.vectocolor(10)

    print('sc.pr():')
    sc.pr(obj)

    payload = obj.data
    print('sc.pp():')
    sc.pp(payload)
    rendered = sc.pp(payload, doprint=False)

    print('sc.printdata():')
    sc.printdata(payload)
    return rendered
예제 #3
0
파일: app.py 프로젝트: sciris/scirisweb
def computation(seed=0, n=1000):
    '''
    Build a random scatter plot and return it converted by sw.mpld3ify.

    Args:
        seed: random seed; cast to int before seeding pylab's RNG
        n:    number of points to draw

    Returns:
        The result of sw.mpld3ify(fig, jsonify=False), i.e. the figure as a
        dict rather than a JSON string.
    '''
    # Make graph
    pl.seed(int(seed))
    fig = pl.figure()
    ax = fig.add_subplot(111)
    xdata = pl.randn(n)
    ydata = pl.randn(n)
    colors = sc.vectocolor(pl.sqrt(xdata**2 + ydata**2))  # Color by distance from the origin
    ax.scatter(xdata, ydata, c=colors)

    # Convert to FE
    try:
        graphjson = sw.mpld3ify(fig, jsonify=False)  # Convert to dict
    finally:
        # pyplot keeps a reference to every figure created with pl.figure();
        # close it once converted so repeated calls don't accumulate figures
        pl.close(fig)
    return graphjson  # Return the JSON representation of the Matplotlib figure
예제 #4
0
def pairplotpars(df,
                 inds=None,
                 keys=None,
                 color_column=None,
                 bounds=None,
                 cmap='parula',
                 bins=None,
                 edgecolor='w',
                 facecolor='#F8A493',
                 figsize=(20, 16)):
    '''
    Plot scatterplots (lower), histograms (diagonal), and hexbins (upper).

    Args:
        df (DataFrame): data to plot; it is not modified (a copy is used)
        inds: optional row positions (passed to ``df.iloc``) to restrict the plot to
        keys: optional column name, or list/tuple of column names, to plot
        color_column (str): column whose values set the scatter colors via ``cmap``
        bounds (dict): optional mapping of column name -> axis limits
        cmap: colormap used when ``color_column`` is given
        bins: histogram bins for the diagonal plots
        edgecolor: histogram edge color
        facecolor: histogram face color; also the color of every scatter point
            when no ``color_column`` is given
        figsize: figure size in inches

    Returns:
        The seaborn PairGrid.
    '''
    # Always work on a copy so the caller's dataframe doesn't gain a 'color_column'
    if inds is not None:
        df = df.iloc[inds, :].copy()
    else:
        df = df.copy()

    # Choose the colors
    if color_column:
        colors = sc.vectocolor(df[color_column].values, cmap=cmap)
        df['color_column'] = [sc.rgb2hex(rgba[:3]) for rgba in colors]  # Drop the alpha channel
    else:
        # facecolor is already a color spec (e.g. '#F8A493'); slicing it like an
        # RGBA array would corrupt it, so use it as-is for every point
        df['color_column'] = [facecolor] * len(df)

    if keys is not None:
        keys = [keys] if isinstance(keys, str) else list(keys)
        df = df.loc[:, keys + ['color_column']].copy()

    # Make the plot
    grid = sns.PairGrid(df)
    grid = grid.map_lower(pl.scatter, facecolors=df['color_column'])
    grid = grid.map_diag(pl.hist,
                         bins=bins,
                         edgecolor=edgecolor,
                         facecolor=facecolor)
    grid = grid.map_upper(pl.hexbin,
                          cmap="Blues",
                          edgecolor="none",
                          gridsize=25)
    grid.fig.set_size_inches(figsize)
    grid.fig.tight_layout()

    # Set bounds: apply limits to every axis whose label names a bounded column
    if bounds:
        for ax in grid.axes.flatten():
            xlabel = ax.get_xlabel()
            ylabel = ax.get_ylabel()
            if xlabel in bounds:
                ax.set_xlim(bounds[xlabel])
            if ylabel in bounds:
                ax.set_ylim(bounds[ylabel])

    return grid
예제 #5
0
def pairplotpars(data,
                 inds=None,
                 color_column=None,
                 bounds=None,
                 cmap='parula',
                 bins=None,
                 edgecolor='w',
                 facecolor='#F8A493',
                 figsize=(20, 16)):
    '''
    Plot scatterplots (lower), histograms (diagonal), and kernel densities (upper).

    Args:
        data: mapping of column name -> values; deep-copied before use
        inds: optional row positions (passed to ``df.iloc``) to restrict the plot to
        color_column (str): column whose values set the scatter colors via ``cmap``
        bounds (dict): optional mapping of column name -> axis limits
        cmap: colormap used when ``color_column`` is given
        bins: histogram bins for the diagonal plots
        edgecolor: histogram edge color
        facecolor: histogram face color; also the color of every scatter point
            when no ``color_column`` is given
        figsize: figure size in inches

    Returns:
        The seaborn PairGrid.
    '''
    import seaborn as sns  # Optional import

    data = sc.odict(sc.dcp(data))

    # Create the dataframe
    df = pd.DataFrame.from_dict(data)
    if inds is not None:
        df = df.iloc[inds, :].copy()

    # Choose the colors
    if color_column:
        colors = sc.vectocolor(df[color_column].values, cmap=cmap)
        df['color_column'] = [sc.rgb2hex(rgba[:3]) for rgba in colors]  # Drop the alpha channel
    else:
        # facecolor is already a color spec (e.g. '#F8A493'); slicing it like an
        # RGBA array would corrupt it, so use it as-is for every point
        df['color_column'] = [facecolor] * len(df)

    # Make the plot
    grid = sns.PairGrid(df)
    grid = grid.map_lower(pl.scatter, facecolors=df['color_column'])
    grid = grid.map_diag(pl.hist,
                         bins=bins,
                         edgecolor=edgecolor,
                         facecolor=facecolor)
    grid = grid.map_upper(sns.kdeplot)
    grid.fig.set_size_inches(figsize)
    grid.fig.tight_layout()

    # Set bounds: apply limits to every axis whose label names a bounded column
    if bounds:
        for ax in grid.axes.flatten():
            xlabel = ax.get_xlabel()
            ylabel = ax.get_ylabel()
            if xlabel in bounds:
                ax.set_xlim(bounds[xlabel])
            if ylabel in bounds:
                ax.set_ylim(bounds[ylabel])

    return grid
예제 #6
0
import pylab as pl
import sciris as sc
import hiptool as hp

sc.heading('Initializing...')
sc.tic() # Start the sciris timer

dosave = True
missing_data = ['remove', 'assumption'][1] # Choose how to handle missing data
# Spending levels to evaluate; the larger levels in the trailing comment are disabled
spendings = [0.1, 0.3, 1, 3, 10, 30, 100, 300, 1000]#, 3000, 10000]
nspendings = len(spendings)
colors = sc.vectocolor(len(spendings)) # Colors for the spending levels

# Load input files
D = sc.odict() # Data
R = sc.odict() # Results
bod_data = sc.loadobj('gbd-data.dat')
country_data = sc.loadspreadsheet('country-data.xlsx')
baseline_factor = country_data.findrow('Zambia', asdict=True)['icer_multiplier'] # Zambia was used for this

# Create default
P = hp.Project()
P.loadburden(filename='rapid_BoD.xlsx')
P.loadinterventions(filename='rapid_interventions.xlsx')
ninterventions = P.intervsets[0].data.nrows # Rows in the first intervention set's data (presumably one per intervention)

# Load data
missing_data_adjustment_factor = 2
sc.heading('Loading data...')
# NOTE(review): in this excerpt the loop only reports progress; D, R, missing_data,
# missing_data_adjustment_factor etc. are presumably used by code not shown here
for c,country in enumerate(country_data['name'].tolist()):
    print(f'  Working on {country}...')
def plot():
    '''
    Draw "Fig. 2: Transmission dynamics" and return the figure.

    Panels:
        a: daily transmissions per contact layer, with pie-chart insets of the
           layer split before Mar 12 and from Mar 23 onward
        b: proportion of transmissions by number of transmissions per case,
           with a pie inset grouping 1-2, 3-4, 5-7 and >7
        c: cumulative % of transmissions vs % of primary infections, with
           annotated 0%, 20% and 50% thresholds
        d: proportion of transmissions by days since symptom onset (onset read
           from dd['s'], presumably the source's record), with a pie inset

    Relies on module-level names not defined here: ``sim``, ``tt``, ``wf``
    (subtracted from font sizes), ``pieargs``, ``np``, ``cv``, ``mdates``,
    ``pl`` and ``sc``.

    Returns:
        fig: the figure object
    '''
    fig = pl.figure(num='Fig. 2: Transmission dynamics', figsize=(20,14))
    # Axes layout in figure coordinates: pie insets on top, time series, and a bottom row of three panels
    piey, tsy, r3y = 0.68, 0.50, 0.07
    piedx, tsdx, r3dx = 0.2, 0.9, 0.25
    piedy, tsdy, r3dy = 0.2, 0.47, 0.35
    pie1x, pie2x = 0.12, 0.65
    tsx = 0.07
    dispx, cumx, sympx = tsx, 0.33+tsx, 0.66+tsx
    ts_ax   = pl.axes([tsx, tsy, tsdx, tsdy])
    pie_ax1 = pl.axes([pie1x, piey, piedx, piedy])
    pie_ax2 = pl.axes([pie2x, piey, piedx, piedy])
    symp_ax = pl.axes([sympx, r3y, r3dx, r3dy])
    disp_ax = pl.axes([dispx, r3y, r3dx, r3dy])
    cum_ax  = pl.axes([cumx, r3y, r3dx, r3dy])

    # Panel letters, placed left of each panel's top edge
    off = 0.06
    txtdispx, txtcumx, txtsympx = dispx-off, cumx-off, sympx-off+0.02
    tsytxt = tsy+tsdy
    r3ytxt = r3y+r3dy
    labelsize = 40-wf
    pl.figtext(txtdispx, tsytxt, 'a', fontsize=labelsize)
    pl.figtext(txtdispx, r3ytxt, 'b', fontsize=labelsize)
    pl.figtext(txtcumx,  r3ytxt, 'c', fontsize=labelsize)
    pl.figtext(txtsympx, r3ytxt, 'd', fontsize=labelsize)


    #%% Fig. 2A -- Time series plot

    layer_keys = list(sim.layer_keys())
    layer_mapping = {k:i for i,k in enumerate(layer_keys)}
    n_layers = len(layer_keys)
    colors = sc.gridcolors(n_layers)

    # Transmissions per day and layer, each weighted by the sim's rescale factor for that day
    layer_counts = np.zeros((sim.npts, n_layers))
    for source_ind, target_ind in tt.count_transmissions():
        dd = tt.detailed[target_ind]
        date = dd['date']
        layer_num = layer_mapping[dd['layer']]
        layer_counts[date, layer_num] += sim.rescale_vec[date]

    # Intervention dates, both as dates and as sim day indices
    mar12 = cv.date('2020-03-12')
    mar23 = cv.date('2020-03-23')
    mar12d = sim.day(mar12)
    mar23d = sim.day(mar23)

    # NOTE(review): labels (and the pie labels below) are matched to layers by position,
    # so this assumes sim.layer_keys() returns exactly these five layers in this order -- confirm
    labels = ['Household', 'School', 'Workplace', 'Community', 'LTCF']
    for l in range(n_layers):
        ts_ax.plot(sim.datevec, layer_counts[:,l], c=colors[l], lw=3, label=labels[l])
    sc.setylim(ax=ts_ax)
    sc.boxoff(ax=ts_ax)
    ts_ax.set_ylabel('Transmissions per day')
    ts_ax.set_xlim([sc.readdate('2020-01-18'), sc.readdate('2020-06-09')])
    ts_ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))
    ts_ax.set_xticks([sim.date(d, as_date=True) for d in np.arange(0, sim.day('2020-06-09'), 14)]) # One tick every 14 days
    ts_ax.legend(frameon=False, bbox_to_anchor=(0.85,0.1))

    # Dashed vertical lines marking the two intervention dates, with labels above the plot
    color = [0.2, 0.2, 0.2]
    ts_ax.axvline(mar12, c=color, linestyle='--', alpha=0.4, lw=3)
    ts_ax.axvline(mar23, c=color, linestyle='--', alpha=0.4, lw=3)
    yl = ts_ax.get_ylim()
    labely = yl[1]*1.015
    # The padding spaces push each centered label to one side of its line
    ts_ax.text(mar12, labely, 'Schools close                     ', color=color, alpha=0.9, style='italic', horizontalalignment='center')
    ts_ax.text(mar23, labely, '                   Stay-at-home', color=color, alpha=0.9, style='italic', horizontalalignment='center')


    #%% Fig. 2A inset -- Pie charts

    # Percentage of transmissions per layer before Mar 12 and from Mar 23 onward
    pre_counts = layer_counts[0:mar12d, :].sum(axis=0)
    post_counts = layer_counts[mar23d:, :].sum(axis=0)
    pre_counts = pre_counts/pre_counts.sum()*100
    post_counts = post_counts/post_counts.sum()*100

    # Extra spaces in some labels nudge their placement around the pie
    lpre = [
        f'Household\n{pre_counts[0]:0.1f}%',
        f'School\n{pre_counts[1]:0.1f}% ',
        f'Workplace\n{pre_counts[2]:0.1f}%    ',
        f'Community\n{pre_counts[3]:0.1f}%',
        f'LTCF\n{pre_counts[4]:0.1f}%',
    ]

    lpost = [
        f'Household\n{post_counts[0]:0.1f}%',
        f'School\n{post_counts[1]:0.1f}%',
        f'Workplace\n{post_counts[2]:0.1f}%',
        f'Community\n{post_counts[3]:0.1f}%',
        f'LTCF\n{post_counts[4]:0.1f}%',
    ]

    pie_ax1.pie(pre_counts, colors=colors, labels=lpre, **pieargs)
    pie_ax2.pie(post_counts, colors=colors, labels=lpost, **pieargs)

    pie_ax1.text(0, 1.75, 'Transmissions by layer\nbefore schools closed', style='italic', horizontalalignment='center')
    pie_ax2.text(0, 1.75, 'Transmissions by layer\nafter stay-at-home', style='italic', horizontalalignment='center')


    #%% Fig. 2B -- histogram by overdispersion

    # Process targets (number of transmissions per case, presumably limited to end_day=mar12)
    n_targets = tt.count_targets(end_day=mar12)

    # Handle bins: one unit-width bin per integer 0..max_infections
    max_infections = n_targets.max()
    edges = np.arange(0, max_infections+2)

    # Analysis
    counts = np.histogram(n_targets, edges)[0]
    bins = edges[:-1] # Remove last bin since it's an edge
    norm_counts = counts/counts.sum()
    raw_counts = counts*bins # Transmissions contributed by each bin
    total_counts = raw_counts/raw_counts.sum()*100
    n_bins = len(bins)
    index = np.linspace(0, 100, len(n_targets))
    sorted_arr = np.sort(n_targets)
    sorted_sum = np.cumsum(sorted_arr)
    sorted_sum = sorted_sum/sorted_sum.max()*100 # Cumulative % of transmissions
    change_inds = sc.findinds(np.diff(sorted_arr) != 0) # Last position of each distinct value in sorted_arr (except the largest)

    pl.set_cmap('Spectral_r')
    sscolors = sc.vectocolor(n_bins) # No cmap given; presumably picks up the default colormap set above

    width = 1.0
    for i in range(n_bins):
        disp_ax.bar(bins[i], total_counts[i], width=width, facecolor=sscolors[i])
    disp_ax.set_xlabel('Number of transmissions per case')
    disp_ax.set_ylabel('Proportion of transmissions (%)')
    sc.boxoff() # NOTE(review): no ax given, so this presumably acts on the current axes (cum_ax, the last one created), not disp_ax
    disp_ax.set_xlim([0.5, 32.5])
    disp_ax.set_xticks(np.arange(0, 32.5, 4))
    sc.boxoff(ax=disp_ax)

    # Pie inset grouping transmissions by 1-2, 3-4, 5-7 and >7 per case
    dpie_ax = pl.axes([dispx+0.05, 0.20, 0.2, 0.2])
    trans1 = total_counts[1:3].sum()
    trans2 = total_counts[3:5].sum()
    trans3 = total_counts[5:8].sum()
    trans4 = total_counts[8:].sum()
    labels = [
        f'1-2:\n{trans1:0.0f}%',
        f' 3-4:\n {trans2:0.0f}%',
        f'5-7: \n{trans3:0.0f}%\n',
        f'>7:  \n{trans4:0.0f}%\n',
        ]
    dpie_args = sc.mergedicts(pieargs, dict(labeldistance=1.2)) # Slightly smaller label distance
    # NOTE(review): indexing sscolors up to 12 assumes n_bins >= 13, i.e. some case had at least 12 transmissions
    dpie_ax.pie([trans1, trans2, trans3, trans4], labels=labels, colors=sscolors[[0,4,7,12]], **dpie_args)


    #%% Fig. 2C -- cumulative distribution function

    rev_ind = 100 - index
    n_change_inds = len(change_inds)
    # change_bins[i] is the distinct count that starts right after change_inds[i]
    change_bins = bins[counts>0][1:]
    for i in range(n_change_inds):
        ib = int(change_bins[i])
        ci = change_inds[i]
        ici = index[ci]
        sci = sorted_sum[ci]
        color = sscolors[ib]
        if i>0:
            # Connect consecutive change points with a segment colored by count
            cim1 = change_inds[i-1]
            icim1 = index[cim1]
            scim1 = sorted_sum[cim1]
            cum_ax.plot([icim1, ici], [scim1, sci], lw=4, c=color)
        cum_ax.scatter([ici], [sci], s=150, zorder=50-i, c=[color], edgecolor='w', linewidth=0.2)
        if ib<=6 or ib in [8, 10, 25]: # Only label selected counts; text offsets are hand-tuned
            xoff = 5 - 2*(ib==1) + 3*(ib>=10) + 1*(ib>=20)
            yoff = 2*(ib==1)
            cum_ax.text(ici-xoff, sci+yoff, ib, fontsize=18-wf, color=color)
    cum_ax.set_xlabel('Proportion of primary infections (%)')
    cum_ax.set_ylabel('Proportion of transmissions (%)')
    xmin = -2
    ymin = -2
    cum_ax.set_xlim([xmin, 102])
    cum_ax.set_ylim([ymin, 102])
    sc.boxoff(ax=cum_ax)

    # Draw horizontal lines and annotations
    ancol1 = [0.2, 0.2, 0.2]
    ancol2 = sscolors[0]
    ancol3 = sscolors[6]

    # Last positions where cumulative transmissions are still 0%, <=20% and <=50%
    i01 = sc.findlast(sorted_sum==0)
    i20 = sc.findlast(sorted_sum<=20)
    i50 = sc.findlast(sorted_sum<=50)
    cum_ax.plot([xmin, index[i01]], [0, 0], '--', lw=2, c=ancol1)
    cum_ax.plot([xmin, index[i20], index[i20]], [20, 20, ymin], '--', lw=2, c=ancol2)
    cum_ax.plot([xmin, index[i50], index[i50]], [50, 50, ymin], '--', lw=2, c=ancol3)

    # Compute mean number of transmissions for 80% and 50% thresholds
    q80 = sc.findfirst(np.cumsum(total_counts)>20) # Count corresponding to 80% of cumulative infections (100-80)
    q50 = sc.findfirst(np.cumsum(total_counts)>50) # Count corresponding to 50% of cumulative infections
    # Mean of bins[q:] weighted by the fraction of cases in each bin
    n80, n50 = [sum(bins[q:]*norm_counts[q:]/norm_counts[q:].sum()) for q in [q80, q50]]

    # Plot annotations
    kw = dict(bbox=dict(facecolor='w', alpha=0.9, lw=0), fontsize=20-wf)
    cum_ax.text(2, 3, f'{index[i01]:0.0f}% of infections\ndo not transmit', c=ancol1, **kw)
    cum_ax.text(8, 23, f'{rev_ind[i20]:0.0f}% of infections cause\n80% of transmissions\n(mean: {n80:0.1f} per infection)', c=ancol2, **kw)
    cum_ax.text(14, 53, f'{rev_ind[i50]:0.0f}% of infections cause\n50% of transmissions\n(mean: {n50:0.1f} per infection)', c=ancol3, **kw)


    #%% Fig. 2D -- histogram by date of symptom onset

    # Calculate: rescaled transmission counts keyed by (transmission date - symptom-onset date);
    # transmissions whose dd['s'] record has a NaN onset date count as asymptomatic
    asymp_count = 0
    symp_counts = {}
    minind = -5
    maxind = 15
    for _, target_ind in tt.transmissions:
        dd = tt.detailed[target_ind]
        date = dd['date']
        delta = sim.rescale_vec[date] # Increment counts by this much
        if dd['s']:
            if tt.detailed[dd['source']]['date'] <= date: # Skip dynamical scaling reinfections
                sdate = dd['s']['date_symptomatic']
                if np.isnan(sdate):
                    asymp_count += delta
                else:
                    ind = int(date - sdate)
                    if ind not in symp_counts:
                        symp_counts[ind] = 0
                    symp_counts[ind] += delta

    # Convert to an array: one slot per day offset from minind-1 to maxind; the first slot
    # also collects offsets < minind and the last slot also collects offsets > maxind
    xax = np.arange(minind-1, maxind+1)
    sympcounts = np.zeros(len(xax))
    for i,val in symp_counts.items():
        if i<minind:
            ind = 0
        elif i>maxind:
            ind = -1
        else:
            ind = sc.findinds(xax==i)[0]
        sympcounts[ind] += val

    # Plot
    total_count = asymp_count + sympcounts.sum()
    sympcounts = sympcounts/total_count*100
    presymp = sc.findinds(xax<=0)[-1] # Position of day 0: earlier slots are presymptomatic, day 0 onward symptomatic
    colors = ['#eed15b', '#ee943a', '#c3211a']

    asymp_frac = asymp_count/total_count*100
    pre_frac = sympcounts[:presymp].sum()
    symp_frac = sympcounts[presymp:].sum()
    symp_ax.bar(xax[0]-2, asymp_frac, label='Asymptomatic', color=colors[0])
    symp_ax.bar(xax[:presymp], sympcounts[:presymp], label='Presymptomatic', color=colors[1])
    symp_ax.bar(xax[presymp:], sympcounts[presymp:], label='Symptomatic', color=colors[2])
    symp_ax.set_xlabel('Days since symptom onset')
    symp_ax.set_ylabel('Proportion of transmissions (%)')
    symp_ax.set_xticks([minind-3, 0, 5, 10, maxind])
    symp_ax.set_xticklabels(['Asymp.', '0', '5', '10', f'>{maxind}'])
    sc.boxoff(ax=symp_ax)

    # Pie inset of asymptomatic / presymptomatic / symptomatic shares
    spie_ax = pl.axes([sympx+0.05, 0.20, 0.2, 0.2])
    labels = [f'Asymp-\ntomatic\n{asymp_frac:0.0f}%', f' Presymp-\n tomatic\n {pre_frac:0.0f}%', f'Symp-\ntomatic\n{symp_frac:0.0f}%']
    spie_ax.pie([asymp_frac, pre_frac, symp_frac], labels=labels, colors=colors, **pieargs)

    return fig
예제 #8
0
# Run every sim, build its detailed transmission tree, and keep the tree by key
tts = sc.objdict()
for key, sim in sims.items():
    sim.run()
    sim.people.make_detailed_transtree()
    tts[key] = sim.people.transtree.detailed
    if plot_sim:
        # Default plots plus an extra 'Total counts' panel
        to_plot = cv.get_sim_plots()
        to_plot['Total counts'] = [
            'cum_infections', 'cum_diagnoses', 'cum_quarantined',
            'n_quarantined'
        ]
        sim.plot(to_plot=to_plot)

#%% Plotting

# NOTE(review): `sim` here is whichever sim the loop above visited last
colors = sc.vectocolor(sim.n, cmap='parula')

# Shared plot settings
msize = 10
suscol = [0.5, 0.5, 0.5]
plargs = dict(lw=2, alpha=0.5)
idelay = 0.05
daydelay = 0.3
pl.rcParams['font.size'] = 18

# Containers filled per sim key by the code that follows
F = sc.objdict()
T = sc.objdict()
D = sc.objdict()
Q = sc.objdict()

for key in sims.keys():
예제 #9
0
# Map each person's Case identifier to their key in datadict (`mapping` is created earlier)
for ind, person in datadict.items():
    mapping[person.Case] = ind

# For every person, the list of people whose exposure_source is that person
contacts = {}
for ind in datadict.keys():
    contacts[ind] = []

for ind, person in datadict.items():
    if person.exposure_source != -1: # presumably -1 marks "no known exposure source"
        contacts[mapping[person.exposure_source]].append(ind)

# Handle colors
ages_for_color = data.age.to_numpy()
# Sanity check that rows of `data` are in the same order as datadict's values
# NOTE(review): `assert` is stripped under `python -O`; use an explicit check if this must always run
for i, person in enumerate(datadict.values()):
    assert ages_for_color[i] == person.age
colors = sc.vectocolor(ages_for_color, cmap=sc.parulacolormap())

# Process into days: group rows 0..N-1 of `data` (looked up via .loc) by their `day` value
plotdata = {}
for i in range(N):
    row = data.loc[i]
    if row.day not in plotdata:
        plotdata[row.day] = []
    plotdata[row.day].append(row)

#%% Plotting

# Plot settings
delta = 1.05
yoff = 0.01
markersize = 300
예제 #10
0
def animate_transtree(tt, **kwargs):
    '''
    Plot an animation of the transmission tree; see TransTree.animate() for documentation.

    Args:
        tt: transmission tree; reads ``tt.pop_size``, ``tt.n_days`` and
            ``tt.detailed`` (one entry per person, falsy for people never infected)
        kwargs: animate, verbose, markersize, sus_color, fig_args, axis_args,
            plot_args, delay, font_size, colors, cmap

    Returns:
        fig: the figure object
    '''

    # Settings
    animate = kwargs.get('animate', True)
    verbose = kwargs.get('verbose', False)
    msize = kwargs.get('markersize', 10)
    sus_color = kwargs.get('sus_color', [0.5, 0.5, 0.5])
    fig_args = kwargs.get('fig_args', dict(figsize=(24, 16)))
    axis_args = kwargs.get(
        'axis_args',
        dict(left=0.10,
             bottom=0.05,
             right=0.85,
             top=0.97,
             wspace=0.25,
             hspace=0.25))
    plot_args = kwargs.get('plot_args', dict(lw=2, alpha=0.5))
    delay = kwargs.get('delay', 0.2)
    font_size = kwargs.get('font_size', 18)
    colors = kwargs.get('colors', None)
    cmap = kwargs.get('cmap', 'parula')
    pl.rcParams['font.size'] = font_size
    if colors is None:
        colors = sc.vectocolor(tt.pop_size, cmap=cmap)

    # Initialization: one list of drawables per day
    n = tt.n_days + 1
    frames = [[] for _ in range(n)]
    tests = [[] for _ in range(n)]
    diags = [[] for _ in range(n)]
    quars = [[] for _ in range(n)]

    # Construct each frame of the animation
    for i, entry in enumerate(tt.detailed):  # Loop over every person
        frame = sc.objdict()
        tdq = sc.objdict()  # Short for "tested, diagnosed, or quarantined"

        # This person became infected
        if entry:
            source = entry['source']
            target = entry['target']
            target_date = entry['date']
            # Seed infections and importations have no source (None). Compare with
            # None explicitly: person 0 is a valid source and needs their real date.
            if source is not None:
                source_date = tt.detailed[source]['date']
            else:
                source = 0
                source_date = 0

            # Construct this frame
            frame.x = [source_date, target_date]
            frame.y = [source, target]
            frame.c = colors[source]
            frame.i = True  # If this person is infected
            frames[int(target_date)].append(frame)

            # Handle testing, diagnosis, and quarantine: file this person under the
            # day of each event that happened within the simulated period
            tdq.t = target
            tdq.d = target_date
            tdq.c = colors[target]
            events = [
                (entry.t.date_tested, tests),
                (entry.t.date_diagnosed, diags),
                (entry.t.date_known_contact, quars),
            ]
            for event_date, bucket in events:
                if not np.isnan(event_date) and event_date < n:
                    bucket[int(event_date)].append(tdq)

        # This person did not become infected
        else:
            frame.x = [0]
            frame.y = [i]
            frame.c = sus_color
            frame.i = False
            frames[0].append(frame)

    # Configure plotting
    fig = pl.figure(**fig_args)
    pl.subplots_adjust(**axis_args)
    ax = fig.add_subplot(1, 1, 1)

    # Create the legend on a separate, hidden axes using NaN (invisible) data
    ax2 = pl.axes([0.85, 0.05, 0.14, 0.9])
    ax2.axis('off')
    lcol = colors[0]
    na = np.nan  # Shorten
    pl.plot(na, na, '-', c=lcol, **plot_args, label='Transmission')
    pl.plot(na, na, 'o', c=lcol, markersize=msize, **plot_args, label='Source')
    pl.plot(na, na, '*', c=lcol, markersize=msize, **plot_args, label='Target')
    pl.plot(na,
            na,
            'o',
            c=lcol,
            markersize=msize * 2,
            fillstyle='none',
            **plot_args,
            label='Tested')
    pl.plot(na,
            na,
            's',
            c=lcol,
            markersize=msize * 1.2,
            **plot_args,
            label='Diagnosed')
    pl.plot(na, na, 'x', c=lcol, markersize=msize * 2.0, label='Known contact')
    pl.legend()

    # Plot the animation
    pl.sca(ax)
    for day in range(n):
        pl.title(f'Day: {day}')
        pl.xlim([0, n])
        pl.ylim([0, tt.pop_size])
        pl.xlabel('Day')
        pl.ylabel('Person')
        flist = frames[day]
        tlist = tests[day]
        dlist = diags[day]
        qlist = quars[day]
        if verbose: print(day, flist)
        for f in flist:
            if verbose: print(f)
            pl.plot(f.x[0], f.y[0], 'o', c=f.c, markersize=msize,
                    **plot_args)  # Plot sources
            pl.plot(f.x, f.y, '-', c=f.c,
                    **plot_args)  # Plot transmission lines
            if f.i:  # If this person is infected
                pl.plot(f.x[1],
                        f.y[1],
                        '*',
                        c=f.c,
                        markersize=msize,
                        **plot_args)  # Plot targets
        for tdq in tlist:
            pl.plot(tdq.d,
                    tdq.t,
                    'o',
                    c=tdq.c,
                    markersize=msize * 2,
                    fillstyle='none')  # Tested; No alpha for this
        for tdq in dlist:
            pl.plot(tdq.d,
                    tdq.t,
                    's',
                    c=tdq.c,
                    markersize=msize * 1.2,
                    **plot_args)  # Diagnosed
        for tdq in qlist:
            pl.plot(tdq.d, tdq.t, 'x', c=tdq.c,
                    markersize=msize * 2.0)  # Quarantine; no alpha for this
        pl.plot([0, day], [0.5, 0.5], c='k',
                lw=5)  # Plot the endless march of time
        if animate:  # Whether to animate
            pl.pause(delay)

    return fig
예제 #11
0
def test_plot_pop():
    '''
    Plot a synthetic population on a square grid, colored by age.

    People with a falsy ``sex`` are drawn as circles and the rest as squares.
    When ``plotconnections`` is set, the contacts of 20 randomly chosen people
    are drawn as lines colored by contact layer (H, S, W, C).

    Returns:
        fig: the figure object
    '''
    plotconnections = True
    n = 20000
    alpha = 0.5

    pl.seed(1)
    indices = pl.randint(0, n, 20)  # People whose contacts are drawn

    max_contacts = {'S': 20, 'W': 10}
    population = sp.make_population(n=n,
                                    max_contacts=max_contacts,
                                    as_objdict=True)

    # Lay people out on a square grid, row by row
    nside = np.ceil(np.sqrt(n))
    x, y = np.meshgrid(np.arange(nside), np.arange(nside))
    x = x.flatten()[:n]
    y = y.flatten()[:n]

    # Materialize as a list so people can be indexed by position below
    people = list(population.values())
    for p, person in enumerate(people):
        person.loc = sc.objdict(dict(x=x[p], y=y[p]))
    ages = np.array([person.age for person in people])
    f_inds = [ind for ind, person in enumerate(people) if not person.sex]
    m_inds = [ind for ind, person in enumerate(people) if person.sex]

    fig = pl.figure(figsize=(24, 18))
    pl.subplot(111)
    minval = 0  # Fixed color range rather than ages.min()
    maxval = 100  # Fixed color range rather than ages.max()
    colors = sc.vectocolor(ages, minval=minval, maxval=maxval)
    for marker, inds in zip('os', [f_inds, m_inds]):
        pl.scatter(x[inds], y[inds], marker=marker, c=colors[inds])
    pl.clim([minval, maxval])
    pl.colorbar()

    if plotconnections:
        lcols = dict(H=[0, 0, 0], S=[0, 0.5, 1], W=[0, 0.7, 0], C=[1, 1, 0])
        for index in indices:
            person = people[index]
            contacts = person.contacts
            for lkey, lcol in lcols.items():
                for contactkey in contacts[lkey]:
                    contact = population[contactkey]
                    pl.plot([person.loc.x, contact.loc.x],
                            [person.loc.y, contact.loc.y],
                            c=lcol,
                            alpha=alpha)

    return fig
예제 #12
0
    sc.colorize(showhelp=True)
    sc.colorize('green', 'hi')  # Simple example
    sc.colorize(['yellow', 'bgblack'])
    print('Hello world')
    print('Goodbye world')
    sc.colorize('reset')  # Colorize all output in between
    bluearray = sc.colorize(color='blue', string=str(range(5)), output=True)
    print("c'est bleu: " + bluearray)
    sc.colorize('magenta')  # Now type in magenta for a while
    print('this is magenta')
    sc.colorize('reset')  # Stop typing in magenta

# Test printing functions (only when 'printing' is among the requested tests in `torun`)
if 'printing' in torun:
    example = sc.prettyobj()
    example.data = sc.vectocolor(10)
    print('sc.pr():')
    sc.pr(example)
    print('sc.pp():')
    sc.pp(example.data)
    string = sc.pp(example.data, doprint=False) # Capture the pretty-printed text instead of printing it

# Test profiling functions
if 'profile' in torun:

    def slow_fn():
        n = 10000
        int_list = []
        int_dict = {}
        for i in range(n):
            int_list.append(i)
예제 #13
0
        n_days=30,
        pop_infected=100,
        rand_seed=25857 + p * 241,  #, 29837*(p+298),
        pop_type='random',
        verbose=0,
    )
    sims[key].run()
    results.append(sims[key].results['cum_infections'].values)

#%% Plotting

pl.figure(figsize=(18, 6), dpi=200)
# pl.rcParams['font.size'] = 18

# Panel 1: cumulative infections over time, one line per population size
pl.subplot(1, 3, 1)
colors = sc.vectocolor(pl.log10(popsizes), cmap='parula') # Color by log10 of population size
for k, key in enumerate(keys):
    # key[1:] is parsed as the population size and shown in thousands
    # (presumably keys are a one-character prefix plus the size -- confirm)
    label = f'{int(float(key[1:]))/1000}k: {results[k][-1]:0.0f}'
    pl.plot(results[k], label=label, lw=3, color=colors[k])
    print(label)
# pl.legend()
pl.title('Total number of infections')
pl.xlabel('Day')
pl.ylabel('Number of infections')
sc.commaticks(axis='y')

# Panel 2: cumulative infections as a percentage of each population size
pl.subplot(1, 3, 2)
for k, key in enumerate(keys):
    label = f'{int(float(key[1:]))/1000}k: {results[k][-1]/popsizes[k]*100:0.1f}'
    pl.plot(results[k] / popsizes[k] * 100, label=label, lw=3, color=colors[k])
    print(label)
예제 #14
0
    def plot_histogram(self,
                       bins=None,
                       fig_args=None,
                       width=0.8,
                       font_size=18):
        '''
        Plots a histogram of the number of transmissions.

        Args:
            bins (array): histogram bin edges; defaults to unit-width edges from 0
                to one past the maximum of ``self.n_targets``
            fig_args (dict): extra arguments for pl.figure(), merged over the
                default ``figsize=(24, 15)``
            width (float): width of each pair of side-by-side bars
            font_size (int): font size, applied globally via pl.rcParams

        Returns:
            fig: the figure object
        '''
        if bins is None:
            max_infections = self.n_targets.max()
            bins = np.arange(0, max_infections + 2)

        # Analysis
        counts = np.histogram(self.n_targets, bins)[0]

        bins = bins[:-1]  # Remove last bin since it's an edge
        total_counts = counts * bins  # Transmissions contributed by each bin
        n_bins = len(bins)
        n_trans = sum(total_counts)
        index = np.linspace(0, 100, len(self.n_targets))
        sorted_arr = np.sort(self.n_targets)
        sorted_sum = np.cumsum(sorted_arr)
        sorted_sum = sorted_sum / sorted_sum.max() * 100  # Cumulative % of transmissions
        change_inds = sc.findinds(np.diff(sorted_arr) != 0)  # Last position of each distinct value

        # Plotting: user-supplied fig_args override the default figure size
        fig_args = sc.mergedicts(dict(figsize=(24, 15)), fig_args)
        pl.rcParams['font.size'] = font_size
        fig = pl.figure(**fig_args)
        pl.set_cmap('Spectral')
        pl.subplots_adjust(left=0.08, right=0.92, bottom=0.08, top=0.92)
        colors = sc.vectocolor(n_bins)

        # Left: number of events and number of transmissions per bin, side by side
        pl.subplot(1, 2, 1)
        w05 = width * 0.5
        w025 = w05 * 0.5
        pl.bar(bins - w025,
               counts,
               width=w05,
               facecolor='k',
               label='Number of events')
        for i in range(n_bins):
            label = 'Number of transmissions (events × transmissions per event)' if i == 0 else None
            pl.bar(bins[i] + w025,
                   total_counts[i],
                   width=w05,
                   facecolor=colors[i],
                   label=label)
        pl.xlabel('Number of transmissions per person')
        pl.ylabel('Count')
        pl.xticks(ticks=bins)
        pl.legend()
        pl.title('Numbers of events and transmissions')

        # Top right: each bin's share drawn at its own and every later position,
        # stacked, so the bar at bins[k] reaches the cumulative share up to k
        pl.subplot(2, 2, 2)
        total = 0
        for i in range(n_bins):
            new = total_counts[i] / n_trans * 100
            pl.bar(bins[i:],
                   new,
                   width=width,
                   bottom=total,
                   facecolor=colors[i])
            total += new
        pl.xticks(ticks=bins)
        pl.xlabel('Number of transmissions per person')
        pl.ylabel('Proportion of infections caused (%)')
        pl.title('Proportion of transmissions, by number of transmissions')

        # Bottom right: cumulative share of transmissions vs share of population
        pl.subplot(2, 2, 4)
        pl.plot(index, sorted_sum, lw=3, c='k', alpha=0.5)
        for i in range(len(change_inds)):
            pl.scatter([index[change_inds[i]]], [sorted_sum[change_inds[i]]],
                       s=150,
                       zorder=10,
                       c=[colors[i]],
                       label=f'Transmitted to {i+1} people')
        pl.xlabel(
            'Proportion of population, ordered by the number of people they infected (%)'
        )
        pl.ylabel('Proportion of infections caused (%)')
        pl.legend()
        pl.ylim([0, 100])
        pl.title('Proportion of transmissions, by proportion of population')

        # Inset: cumulative infections over time
        pl.axes([0.25, 0.65, 0.2, 0.2])
        berry = [0.8, 0.1, 0.2]
        pl.plot(self.sim_results.t,
                self.sim_results.cum_infections,
                lw=2,
                c=berry)
        pl.xlabel('Day')
        pl.ylabel('Cumulative infections')

        return fig
예제 #15
0
    def animate(self, *args, **kwargs):
        '''
        Animate the transmission tree.

        Args:
            animate    (bool):  whether to animate the plot (otherwise, show when finished)
            verbose    (bool):  print out progress of each frame
            markersize (int):   size of the markers
            sus_color  (list):  color for susceptibles
            fig_args   (dict):  arguments passed to pl.figure()
            axis_args  (dict):  arguments passed to pl.subplots_adjust()
            plot_args  (dict):  arguments passed to pl.plot()
            delay      (float): delay between frames in seconds
            font_size  (int):   size of the font
            colors     (list):  color of each person
            cmap       (str):   colormap for each person (if colors is not supplied)

        Returns:
            fig: the figure object
        '''

        # Settings
        animate = kwargs.get('animate', True)
        verbose = kwargs.get('verbose', False)
        msize = kwargs.get('markersize', 10)
        sus_color = kwargs.get('sus_color', [0.5, 0.5, 0.5])
        fig_args = kwargs.get('fig_args', dict(figsize=(24, 16)))
        axis_args = kwargs.get(
            'axis_args',
            dict(left=0.10,
                 bottom=0.05,
                 right=0.85,
                 top=0.97,
                 wspace=0.25,
                 hspace=0.25))
        plot_args = kwargs.get('plot_args', dict(lw=2, alpha=0.5))
        delay = kwargs.get('delay', 0.2)
        font_size = kwargs.get('font_size', 18)
        colors = kwargs.get('colors', None)
        cmap = kwargs.get('cmap', 'parula')
        pl.rcParams['font.size'] = font_size
        if colors is None:
            colors = sc.vectocolor(self.pop_size, cmap=cmap)

        # Initialization
        n = self.n_days + 1
        frames = [list() for i in range(n)]
        tests = [list() for i in range(n)]
        diags = [list() for i in range(n)]
        quars = [list() for i in range(n)]

        # Construct each frame of the animation
        for ddict in self.detailed:  # Loop over every person
            if ddict is None:
                continue  # Skip the 'None' node corresponding to seeded infections

            frame = sc.objdict()
            tdq = sc.objdict()  # Short for "tested, diagnosed, or quarantined"
            target = ddict.t
            target_ind = ddict['target']

            if not np.isnan(ddict['date']):  # If this person was infected

                source_ind = ddict[
                    'source']  # Index of the person who infected the target

                target_date = ddict['date']
                if source_ind is not None:  # Seed infections and importations won't have a source
                    source_date = self.detailed[source_ind]['date']
                else:
                    source_ind = 0
                    source_date = 0

                # Construct this frame
                frame.x = [source_date, target_date]
                frame.y = [source_ind, target_ind]
                frame.c = colors[source_ind]
                frame.i = True  # If this person is infected
                frames[int(target_date)].append(frame)

                # Handle testing, diagnosis, and quarantine
                tdq.t = target_ind
                tdq.d = target_date
                tdq.c = colors[int(target_ind)]
                date_t = target['date_tested']
                date_d = target['date_diagnosed']
                date_q = target['date_known_contact']
                if ~np.isnan(date_t) and date_t < n:
                    tests[int(date_t)].append(tdq)
                if ~np.isnan(date_d) and date_d < n:
                    diags[int(date_d)].append(tdq)
                if ~np.isnan(date_q) and date_q < n:
                    quars[int(date_q)].append(tdq)

            else:
                frame.x = [0]
                frame.y = [target_ind]
                frame.c = sus_color
                frame.i = False
                frames[0].append(frame)

        # Configure plotting
        fig = pl.figure(**fig_args)
        pl.subplots_adjust(**axis_args)
        ax = fig.add_subplot(1, 1, 1)

        # Create the legend
        ax2 = pl.axes([0.85, 0.05, 0.14, 0.9])
        ax2.axis('off')
        lcol = colors[0]
        na = np.nan  # Shorten
        pl.plot(na, na, '-', c=lcol, **plot_args, label='Transmission')
        pl.plot(na,
                na,
                'o',
                c=lcol,
                markersize=msize,
                **plot_args,
                label='Source')
        pl.plot(na,
                na,
                '*',
                c=lcol,
                markersize=msize,
                **plot_args,
                label='Target')
        pl.plot(na,
                na,
                'o',
                c=lcol,
                markersize=msize * 2,
                fillstyle='none',
                **plot_args,
                label='Tested')
        pl.plot(na,
                na,
                's',
                c=lcol,
                markersize=msize * 1.2,
                **plot_args,
                label='Diagnosed')
        pl.plot(na,
                na,
                'x',
                c=lcol,
                markersize=msize * 2.0,
                label='Known contact')
        pl.legend()

        # Plot the animation
        pl.sca(ax)
        for day in range(n):
            pl.title(f'Day: {day}')
            pl.xlim([0, n])
            pl.ylim([0, len(self)])
            pl.xlabel('Day')
            pl.ylabel('Person')
            flist = frames[day]
            tlist = tests[day]
            dlist = diags[day]
            qlist = quars[day]
            for f in flist:
                if verbose: print(f)
                pl.plot(f.x[0],
                        f.y[0],
                        'o',
                        c=f.c,
                        markersize=msize,
                        **plot_args)  # Plot sources
                pl.plot(f.x, f.y, '-', c=f.c,
                        **plot_args)  # Plot transmission lines
                if f.i:  # If this person is infected
                    pl.plot(f.x[1],
                            f.y[1],
                            '*',
                            c=f.c,
                            markersize=msize,
                            **plot_args)  # Plot targets
            for tdq in tlist:
                pl.plot(tdq.d,
                        tdq.t,
                        'o',
                        c=tdq.c,
                        markersize=msize * 2,
                        fillstyle='none')  # Tested; No alpha for this
            for tdq in dlist:
                pl.plot(tdq.d,
                        tdq.t,
                        's',
                        c=tdq.c,
                        markersize=msize * 1.2,
                        **plot_args)  # Diagnosed
            for tdq in qlist:
                pl.plot(tdq.d, tdq.t, 'x', c=tdq.c, markersize=msize *
                        2.0)  # Quarantine; no alpha for this
            pl.plot([0, day], [0.5, 0.5], c='k',
                    lw=5)  # Plot the endless march of time
            if animate:  # Whether to animate
                pl.pause(delay)

        return fig
예제 #16
0
def plot_pop(do_show=False, pause=0.2):
    '''
    Build an example synthetic population and, optionally, plot it.

    A population of 5000 people is generated with sp.make_population() and every
    person is placed on a square grid (their 'loc' entry is set in place, whether
    or not anything is plotted). When do_show is True, people are drawn as a
    scatter plot colored by age on a fixed 0-100 scale (circles for people whose
    'sex' entry is falsy, squares otherwise), and the contacts of 20 randomly
    sampled people are drawn as lines colored by contact layer.

    Args:
        do_show (bool): whether to create the figure; if False, nothing is plotted
        pause (float): seconds to pause after drawing each sampled person's
            contacts (0/None disables pausing)

    Returns:
        fig: the figure object if do_show is True, otherwise None
    '''
    plotconnections = True
    n = 5000
    alpha = 0.5

    # Seed the global RNG first so the sampled people are reproducible; the order
    # of these calls is kept as-is since later calls may also draw random numbers
    pl.seed(1)
    indices = pl.randint(0, n, 20)

    max_contacts = {'S': 20, 'W': 10}
    population = sp.make_population(n=n, max_contacts=max_contacts)

    # Lay people out on the smallest square grid that holds all of them
    nside = np.ceil(np.sqrt(n))
    x, y = np.meshgrid(np.arange(nside), np.arange(nside))
    x = x.flatten()[:n]
    y = y.flatten()[:n]

    people = list(population.values())
    for p, person in enumerate(people):
        person['loc'] = dict(x=x[p], y=y[p])
    ages = np.array([person['age'] for person in people])
    f_inds = [ind for ind, person in enumerate(people) if not person['sex']]
    m_inds = [ind for ind, person in enumerate(people) if person['sex']]

    if not do_show:
        return None

    use_terrain = False
    if use_terrain:
        import matplotlib.pyplot as plt
        import matplotlib.colors as mcolors  # Aliased so it is not shadowed by the per-person colors below
        colors_undersea = plt.cm.terrain(np.linspace(0, 0.17, 256))
        colors_land = plt.cm.terrain(np.linspace(0.25, 1, 256))
        all_colors = np.vstack((colors_undersea, colors_land))
        terrain_map = mcolors.LinearSegmentedColormap.from_list('terrain_map', all_colors)
        pl.set_cmap(terrain_map)

    fig = pl.figure(figsize=(24, 18))
    pl.subplot(111)
    minval = 0    # Fixed age range rather than ages.min()/ages.max()
    maxval = 100
    age_colors = sc.vectocolor(ages, minval=minval, maxval=maxval)
    for inds, marker in [(f_inds, 'o'), (m_inds, 's')]:
        pl.scatter(x[inds], y[inds], marker=marker, c=age_colors[inds])
    pl.clim([minval, maxval])
    pl.colorbar()

    if plotconnections:
        # Line color for each contact layer key
        lcols = dict(H=[0, 0, 0],
                     S=[0, 0.5, 1],
                     W=[0, 0.7, 0],
                     C=[1, 1, 0])
        for index in indices:
            person = people[index]
            contacts = person['contacts']
            for lkey, lcol in lcols.items():
                for contactkey in contacts[lkey]:
                    contact = population[contactkey]
                    pl.plot([person['loc']['x'], contact['loc']['x']],
                            [person['loc']['y'], contact['loc']['y']],
                            c=lcol,
                            alpha=alpha)
            if pause:
                pl.pause(pause)

    return fig
예제 #17
0
# Build a custom colormap from two source colormaps joined by a short linear
# "transition" ramp. NOTE(review): cmap1, cmap2, colors1, colors2, new_cmap_name,
# mplt, discrete and sim are defined earlier in the script (not visible here).
# transition_steps = 0  # heat+freeze
transition_steps = 4  # increase if closest ends of the color maps are far apart, values to try: 4, 8, 16
# Sample transition_steps colors blending from cmap1's color at 1.0 to cmap2's color at 0
transition = mplt.colors.LinearSegmentedColormap.from_list(
    "transition", [cmap1(1.), cmap2(0)])(np.linspace(0, 1, transition_steps))
colors = np.vstack((colors1, transition, colors2))
colors = np.flipud(colors)  # Reverse the stacked rows so the colormap runs from colors2's end to colors1's start

new_cmap = mplt.colors.LinearSegmentedColormap.from_list(new_cmap_name, colors)
cmap = new_cmap

# Assign colors to age groups. Note: this rebinds `colors` from the colormap's
# color list to one color per entry of sim.people.age.
age_cutoffs = np.arange(
    0, 101, 10)  # alternative (unused) cutoffs: np.array([0, 4, 6, 18, 22, 30, 45, 65, 80, 90, 100])
if discrete:
    # One color per cutoff; each person takes the color at the last cutoff index
    # where age_cutoffs <= age. Assumes ages are >= 0, otherwise no cutoff matches
    # and the [-1] lookup fails.
    raw_colors = sc.vectocolor(len(age_cutoffs), cmap=cmap)
    colors = []
    for age in sim.people.age:
        ind = sc.findinds(age_cutoffs <= age)[-1]
        colors.append(raw_colors[ind])
    colors = np.array(colors)
else:
    # Continuous coloring on a transformed age scale: 0.1*age + sqrt(age)
    age_map = sim.people.age * 0.1 + np.sqrt(sim.people.age)
    colors = sc.vectocolor(age_map, cmap=cmap)

# Create the legend
if plot_stacked:
    ax = fig.add_axes([0.85, 0.05, 0.14, 0.93])
elif len(keys_to_plot) % 2 != 0:
    ax = fig.add_axes([0.85, 0.05, 0.14, 0.93])
else:
예제 #18
0
def plot_surface(ax, dfr, col=0, colval=0):
    '''
    Draw one R_eff surface panel onto the given axes.

    The x-axis is the number of routine (non-quarantine) tests per 1,000 people
    per day, the y-axis is the 'trprob' column, and the filled contours show
    'r_eff' smoothed onto a regular 100x100 grid with gauss2d(). The R_e = 1
    contour is drawn dotted and labeled, and the raw points are overlaid as a
    scatter plot colored on the same scale as the surface.

    NOTE: relies on module-level names defined elsewhere in the script:
    cs, kcpop, n_days, gauss2d, colormap, verbose and sc.

    Args:
        ax: the matplotlib axes to draw on
        dfr: results indexed by column name ('cum_tests', 'cum_quarantined',
            'trprob', 'r_eff')
        col (int): panel column; 0 adds the y-label, 1 adds the x-label, and it
            also selects where the contour label is placed
        colval (float): mobility fraction shown in the title (e.g. 0.5 -> "50% mobility")

    Returns:
        im: the filled-contour object returned by ax.contourf()
    '''
    # Routine tests = all tests minus those attributed to quarantine, per 1,000 people per day
    total_tests = dfr['cum_tests']
    quarantine_tests = cs.d_calcs.quar_test * dfr['cum_quarantined']
    x = (total_tests - quarantine_tests) * 1000 / kcpop / n_days
    y = np.array(dfr['trprob'])
    z = np.array(dfr['r_eff'])

    # Fixed plotting window and color range
    min_x, max_x = 0, 6
    min_y, max_y = 0, 1.0
    min_z, max_z = 0.25, 1.75

    # Smooth the scattered (x, y, z) points onto a regular grid
    npts = 100
    xi = np.linspace(min_x, max_x, npts)
    yi = np.linspace(min_y, max_y, npts)
    xx, yy = np.meshgrid(xi, yi)
    zz = gauss2d(x, y, z, xi, yi, eps=0.08, xscale=max_x - min_x, yscale=max_y - min_y)

    im = ax.contourf(xx, yy, zz, cmap=colormap, levels=np.linspace(min_z, max_z, 300))

    if col == 0:
        ax.set_ylabel('Contact tracing probability at home and work', labelpad=20)
    ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    if col == 1:
        ax.set_xlabel('Number of routine tests per 1,000 people per day')

    # Dotted R_e = 1 threshold; its label position depends on the panel column
    threshold = ax.contour(xx, yy, zz, levels=[1], colors='k', linestyles=':', linewidths=2)
    if col == 0:
        label_xy = (1, 0.3)
    elif col == 1:
        label_xy = (2, 0.4)
    else:
        label_xy = (3, 0.5)
    ax.clabel(threshold, fmt=r'$R_{e}=%2.1d$', colors='k', fontsize=18, manual=[label_xy])

    ax.set_title('%i%% mobility' % (colval * 100))

    # Overlay the raw points, colored with the same colormap and range as the surface
    point_colors = sc.vectocolor(z, cmap=colormap, minval=min_z, maxval=max_z)
    ax.scatter(x, y, marker='o', c=point_colors, edgecolor=[0.3] * 3,
               s=15, linewidth=0.1, alpha=0.5)

    ax.set_xlim([min_x, max_x])
    ax.set_ylim([min_y, max_y])
    if verbose:
        print(f'Plot: {col}, min zz: {zz.min():0.2f}; max zz: {zz.max():0.2f}; min z: {z.min():0.2f}; max z: {z.max():0.2f}')

    return im
예제 #19
0
    'pop_size': pop_size,  # start with a small pool
    'pop_type': pop_type,  # synthpops, hybrid
    'contacts': contacts[pop_type],
    'n_days': 1,
    # 'rand_seed': None,
}

# Create the sim from `pars` and initialize it (sim.people is used below for ages)
sim = cv.Sim(pars=pars)
sim.initialize()

# Multilayer network container with one aspect. NOTE(review): not used in the
# lines visible here -- presumably populated by the loop that follows
mnet = pymnet.MultilayerNetwork(aspects=1)

# fig = pl.figure(figsize=(16,16), dpi=120)
# Readable names for each contact-layer key
mapping = dict(a='All', h='Households', s='Schools', w='Work', c='Community')
# One color per person, mapped from their age with the 'turbo' colormap
colors = sc.vectocolor(sim.people.age, cmap='turbo')

# Contact-layer keys for this population type, excluding community ('c');
# list.remove raises ValueError if 'c' is not among the keys
keys = list(contacts[pop_type].keys())
keys.remove('c')
# nrowcol = np.ceil(np.sqrt(len(keys)))

# Graph that allows multiple (parallel) edges between the same pair of nodes
G = nx.MultiGraph()

# NOTE(review): these sets are only created here; presumably they track which
# nodes have been added overall and per layer (home/school/work) -- confirm in the loop below
node_set = set()
home_set = set()
school_set = set()
work_set = set()

# Number of iterations of the sampling loop that follows
sample_size = 1

for p in range(sample_size):
예제 #20
0
    def plot_histograms(self, start_day=None, end_day=None, bins=None, width=0.8, fig_args=None, font_size=18):
        '''
        Plots a histogram of the number of transmissions.

        The figure contains:
          - left: the number of people with each number of transmissions (black),
            next to the total transmissions those people caused (colored);
          - top right: the cumulative number of infections caused, stacked by
            transmissions per person;
          - bottom right: the proportion of infections caused versus the proportion
            of the population, with people ordered by how many others they infected;
          - inset: cumulative infections over time, with the counting window shaded.

        Args:
            start_day (int/str): the day on which to start counting people who got infected
            end_day (int/str): the day on which to stop counting people who got infected
            bins (list): bin edges to use for the histogram (default: one bin per
                integer number of transmissions, from 0 to the maximum)
            width (float): width of bars
            fig_args (dict): passed to pl.figure()
            font_size (float): size of font

        Returns:
            fig: the figure object
        '''

        # Process targets
        n_targets = self.count_targets(start_day, end_day)

        # Handle bins
        if bins is None:
            max_infections = n_targets.max() if len(n_targets) else 0 # .max() fails on an empty array
            bins = np.arange(0, max_infections+2)
        bins = np.asarray(bins) # A plain list would break the bins-w025 arithmetic below

        # Analysis
        counts = np.histogram(n_targets, bins)[0]

        bins = bins[:-1] # Remove last bin since it's an edge
        total_counts = counts*bins
        n_bins = len(bins)
        index = np.linspace(0, 100, len(n_targets))
        sorted_arr = np.sort(n_targets)
        sorted_sum = np.cumsum(sorted_arr)
        grand_total = sorted_sum.max() if len(sorted_sum) else 0
        if grand_total > 0:
            sorted_sum = sorted_sum/grand_total*100
        else: # No transmissions at all: avoid a 0/0 division that would make the curve all NaN
            sorted_sum = np.zeros(len(sorted_sum))
        change_inds = sc.findinds(np.diff(sorted_arr) != 0) # Last position of each group of equal counts (except the final group)
        max_labels = 15 # Maximum number of ticks and legend entries to plot

        # Plotting
        fig_args = sc.mergedicts(dict(figsize=(24,15)), fig_args)
        pl.rcParams['font.size'] = font_size
        fig = pl.figure(**fig_args)
        pl.set_cmap('Spectral')
        pl.subplots_adjust(left=0.08, right=0.92, bottom=0.08, top=0.92)
        colors = sc.vectocolor(n_bins)

        pl.subplot(1,2,1)
        w05 = width*0.5
        w025 = w05*0.5
        pl.bar(bins-w025, counts, width=w05, facecolor='k', label='Number of events')
        for i in range(n_bins):
            label = 'Number of transmissions (events × transmissions per event)' if i==0 else None
            pl.bar(bins[i]+w025, total_counts[i], width=w05, facecolor=colors[i], label=label)
        pl.xlabel('Number of transmissions per person')
        pl.ylabel('Count')
        if n_bins<max_labels:
            pl.xticks(ticks=bins)
        pl.legend()
        pl.title('Numbers of events and transmissions')

        pl.subplot(2,2,2)
        total = 0
        for i in range(n_bins):
            # Each bin adds its total to every bar from its own position onward, giving a cumulative stack
            pl.bar(bins[i:], total_counts[i], width=width, bottom=total, facecolor=colors[i])
            total += total_counts[i]
        if n_bins<max_labels:
            pl.xticks(ticks=bins)
        pl.xlabel('Number of transmissions per person')
        pl.ylabel('Number of infections caused')
        pl.title('Number of transmissions, by transmissions per person')

        pl.subplot(2,2,4)
        pl.plot(index, sorted_sum, lw=3, c='k', alpha=0.5)
        n_change_inds = len(change_inds)
        label_inds = np.linspace(0, n_change_inds, max_labels).round() # Don't allow more than this many labels
        for i, ci in enumerate(change_inds):
            # Read the counts on either side of this change point from the data,
            # rather than assuming they are consecutive integers: counts can skip
            # values (e.g. 0, 1, 2, 5). Label with the next group's count and color
            # like the bin holding the current group's count.
            this_val = sorted_arr[ci]
            next_val = sorted_arr[ci+1]
            bin_ind = int(np.clip(np.searchsorted(bins, this_val, side='right')-1, 0, n_bins-1))
            label = f'Transmitted to {next_val:n} people' if i in label_inds else None # Don't plot more than this many labels
            pl.scatter([index[ci]], [sorted_sum[ci]], s=150, zorder=10, c=[colors[bin_ind]], label=label)
        pl.xlabel('Proportion of population, ordered by the number of people they infected (%)')
        pl.ylabel('Proportion of infections caused (%)')
        pl.legend()
        pl.ylim([0, 100])
        pl.grid(True)
        pl.title('Proportion of transmissions, by proportion of population')

        # Inset: cumulative infections over time, with the counting window shaded
        pl.axes([0.30, 0.65, 0.15, 0.2])
        berry      = [0.8, 0.1, 0.2]
        dirty_snow = [0.9, 0.9, 0.9]
        start_day  = self.day(start_day, which='start')
        end_day    = self.day(end_day, which='end')
        pl.axvspan(start_day, end_day, facecolor=dirty_snow)
        pl.plot(self.sim_results['t'], self.sim_results['cum_infections'], lw=2, c=berry)
        pl.xlabel('Day')
        pl.ylabel('Cumulative infections')

        return fig