コード例 #1
0
ファイル: sampling.py プロジェクト: caseypeter10/synthpops
def get_n_contact_ids_by_age(contact_ids_by_age_dic, contact_ages,
                             age_brackets, age_by_brackets_dic):
    """
    Get ids for the contacts with ages in contact_ages.

    For each requested age, one id is drawn uniformly at random from the ids
    stored under the available age nearest to it. If that nearest age has no
    ids, the id is drawn instead from all ids whose ages fall in the requested
    age's bracket.

    Args:
        contact_ids_by_age_dic (dict): dictionary mapping an age to the list of ids of individuals with that age
        contact_ages (list)          : list of integer ages
        age_brackets (dict)          : dictionary mapping age bracket keys to age bracket range
        age_by_brackets_dic (dict)   : dictionary mapping age to the age bracket range it falls in

    Returns:
        set: ids of the contacts sampled for contact_ages. Because duplicates
        collapse in a set, it may hold fewer than len(contact_ages) ids.

    Raises:
        ValueError: if neither the nearest age nor the requested age's bracket
        has any ids to sample from.
    """
    contact_ids = set()
    # The available ages do not change inside the loop, so sort them once.
    age_list = sorted(contact_ids_by_age_dic)
    for contact_age in contact_ages:
        ind = sc.findnearest(age_list, contact_age)
        these_ids = contact_ids_by_age_dic[age_list[ind]]
        if len(these_ids) > 0:
            contact_id = np.random.choice(these_ids)
        else:
            # Nobody at the nearest age: widen the pool to the whole bracket.
            b_contact = age_by_brackets_dic[contact_age]
            potential_contacts = []
            for a in age_brackets[b_contact]:
                potential_contacts += contact_ids_by_age_dic[a]
            if not potential_contacts:
                raise ValueError(f'No candidate contact ids found for age {contact_age}')
            contact_id = np.random.choice(potential_contacts)
        contact_ids.add(contact_id)
    return contact_ids
コード例 #2
0
ファイル: draw_networks_v5.py プロジェクト: sba5827/synthpops
    colors = np.array(colors)
else:
    age_map = sim.people.age * 0.1 + np.sqrt(sim.people.age)
    colors = sc.vectocolor(age_map, cmap=cmap)

# Create the legend
# The legend lives on its own, axis-less axes at the right-hand side of the
# figure; stacked plots and an odd-length keys_to_plot share one placement.
if plot_stacked or len(keys_to_plot) % 2 != 0:
    ax = fig.add_axes([0.85, 0.05, 0.14, 0.93])
else:
    ax = fig.add_axes([0.82, 0.05, 0.14, 0.90])

ax.axis('off')
for age in age_cutoffs:
    # Colour each entry like the element of `colors` at the index whose
    # sim.people.age value is nearest to this cutoff.
    nearest_age = sc.findnearest(sim.people.age, age)
    col = colors[nearest_age]
    # NaN coordinates draw nothing visible but still register a legend entry.
    plt.plot(np.nan, np.nan, 'o', c=col,
             label=f'Age {age}+' if age == 100 else f'Age {age}-{age+9}')
plt.legend(fontsize=18)

# Find indices
idict = {}
hdfs = {}
for layer in keys:
    hdf = sim.people.contacts[layer].to_df()

    hdfs[layer] = hdf
    idict[layer] = list(
コード例 #3
0
def plotter(key,
            sims,
            ax,
            ys=None,
            calib=False,
            label='',
            ylabel='',
            low_q=0.025,
            high_q=0.975,
            flabel=True,
            startday=None,
            subsample=2,
            chooseseed=None):
    """
    Plot a central estimate and a quantile band of one result across sims.

    Also overlays the sim's data for ``key`` (when present), prints summary
    statistics for a few known keys, and sets month-start x ticks on ``ax``.
    The curves are drawn with ``pl.*`` calls, i.e. on the current axes.
    Reads the module-level name ``calibration_end`` (not defined in this
    excerpt) as the last plotted day.

    Args:
        key (str): result key such as 'cum_infections'; the text after the first
            underscore (``which``) selects the colour and appears in printouts.
        sims (list): simulations; ``sims[0]`` supplies data, dates and ``day()``.
        ax: matplotlib axes whose x ticks are set.
        ys (list, optional): per-sim result arrays; read from ``sims`` if None.
        calib (bool): has no effect (both of its former branches set the same,
            later-overwritten ticks); kept for backward compatibility.
        label (str): legend label for the central line.
        ylabel (str): y-axis label.
        low_q, high_q (float): quantiles bounding the shaded band.
        flabel (bool): whether the band gets a legend label.
        startday (optional): first day to plot, in any form ``sim.day()``
            accepts; None starts at the beginning.
        subsample (int): plot every ``subsample``-th data point.
        chooseseed (int, optional): index of the sim whose trajectory is the
            central line; by default the median across sims is used.
    """
    which = key.split('_')[1]
    try:
        color = cv.get_colors()[which]
    except Exception:
        # No predefined colour for this result: fall back to grey.
        color = [0.5, 0.5, 0.5]
    if which == 'diagnoses':
        color = [0.03137255, 0.37401, 0.63813918, 1.]
    elif which == '':
        color = [0.82400815, 0., 0., 1.]

    if ys is None:
        ys = [s.results[key].values for s in sims]

    yarr = np.array(ys)
    if chooseseed is not None:
        best = sims[chooseseed].results[key].values
    else:
        best = np.median(yarr, axis=0)
    low = np.quantile(yarr, q=low_q, axis=0)
    high = np.quantile(yarr, q=high_q, axis=0)

    # Width of the band in percent, e.g. '95' for the default quantiles, so
    # labels and printouts stay truthful when low_q/high_q are changed.
    interval_pct = f'{(high_q - low_q) * 100:g}'

    sim = sims[0]  # For having a sim to refer to

    tvec = np.arange(len(best))
    # Guard against sim.data being None (presumably a sim without loaded data).
    if sim.data is not None and key in sim.data:
        data_t = np.array(
            (sim.data.index - sim['start_day']) / np.timedelta64(1, 'D'))
        inds = np.arange(0, len(data_t), subsample)
        # The index is date-like (it is subtracted from start_day above), so
        # selecting by integer positions must be explicit via .iloc.
        pl.plot(data_t[inds],
                sim.data[key].iloc[inds],
                'd',
                c=color,
                markersize=15,
                alpha=0.75,
                label='Data')

    # Convert startday (possibly a date) to an index; slicing with the raw
    # value would fail for date strings.
    start = None
    if startday is not None:
        start = sim.day(startday)
    end = sim.day(calibration_end)
    fill_label = f'{interval_pct}% projected interval' if flabel else None
    # NOTE(review): the slices stop before index `end`, while the printed
    # 'overall'/'last day' stats read index `end` itself; confirm intent.
    pl.fill_between(tvec[start:end],
                    low[start:end],
                    high[start:end],
                    facecolor=color,
                    alpha=0.2,
                    label=fill_label)
    pl.plot(tvec[start:end],
            best[start:end],
            c=color,
            label=label,
            lw=4,
            alpha=1.0)

    # Print some stats
    if key == 'cum_infections':
        jul25 = sim.day("2020-07-25")
        print(
            f'Estimated {which} on July 25: {best[jul25]} ({interval_pct}%: {low[jul25]}-{high[jul25]})'
        )
        print(
            f'Estimated {which} overall: {best[end]} ({interval_pct}%: {low[end]}-{high[end]})'
        )
    elif key == 'n_infectious':
        peakday = np.argmax(best)  # first index of the maximum
        peakval = best[peakday]
        print(
            f'Estimated peak {which} on {sim.date(peakday)}: {peakval} ({interval_pct}%: {low[peakday]}-{high[peakday]})'
        )
        print(
            f'Estimated {which} on last day: {best[end]} ({interval_pct}%: {low[end]}-{high[end]})'
        )
    elif key == 'cum_diagnoses':
        print(
            f'Estimated {which} overall: {best[end]} ({interval_pct}%: {low[end]}-{high[end]})'
        )

    sc.setylim()

    pl.ylabel(ylabel)
    # Tick the first day of each month from July to October 2020.
    month_starts = ('2020-07-01', '2020-08-01', '2020-09-01', '2020-10-01')
    datemarks = np.array([sim.day(d) for d in month_starts]) * 1.
    ax.set_xticks(datemarks)