コード例 #1
0
ファイル: test_plot_tools.py プロジェクト: tianyuan/synthpops
def test_plot_generated_trimmed_contact_matrix(setting_code='H', n=5000, aggregate_flag=True, logcolors_flag=True,
                                               density_or_frequency='density'):
    """
    Build a Seattle-metro contact network, trim it, and plot its age mixing matrix.

    Args:
        setting_code (str)         : contact setting to plot; forwarded to ``sp.plot_contact_frequency``
        n (int)                    : number of people in the generated network (cast with ``int``)
        aggregate_flag (bool)      : forwarded to ``sp.plot_contact_frequency``
        logcolors_flag (bool)      : forwarded to ``sp.plot_contact_frequency``
        density_or_frequency (str) : contact weighting, forwarded to ``sp.calculate_contact_matrix``
                                     and ``sp.plot_contact_frequency``

    Returns:
        The figure returned by ``sp.plot_contact_frequency``.
    """
    datadir = sp.datadir

    state_location = 'Washington'
    location = 'seattle_metro'
    country_location = 'usa'

    popdict = {}

    options_args = {'use_microstructure': True}
    network_distr_args = {'Npop': int(n)}
    # Pass country_location explicitly so the network and the census age
    # brackets below are built for the same location instead of the network
    # relying on make_contacts' own default.
    contacts = sp.make_contacts(popdict, country_location=country_location, state_location=state_location,
                                location=location, options_args=options_args,
                                network_distr_args=network_distr_args)
    contacts = sp.trim_contacts(contacts, trimmed_size_dic=None, use_clusters=False)

    age_brackets = sp.get_census_age_brackets(datadir, state_location=state_location, country_location=country_location)
    age_by_brackets_dic = sp.get_age_by_brackets_dic(age_brackets)

    # Single-year age of every person in the trimmed network.
    ages = [contacts[uid]['age'] for uid in contacts]

    age_count = Counter(ages)
    aggregate_age_count = sp.get_aggregate_ages(age_count, age_by_brackets_dic)

    freq_matrix_dic = sp.calculate_contact_matrix(contacts, density_or_frequency)

    fig = sp.plot_contact_frequency(freq_matrix_dic, age_count, aggregate_age_count, age_brackets, age_by_brackets_dic,
                                    setting_code, density_or_frequency, logcolors_flag, aggregate_flag)

    return fig
コード例 #2
0
def plot_generated_trimmed_contact_matrix(datadir,
                                          n,
                                          location='seattle_metro',
                                          state_location='Washington',
                                          country_location='usa',
                                          setting_code='H',
                                          aggregate_flag=True,
                                          logcolors_flag=True,
                                          density_or_frequency='density',
                                          trimmed_size_dic=None):
    """
    Generate a contact network, trim it, and plot the age mixing matrix for one setting.

    Args:
        datadir (str)              : data directory passed to ``sp.get_census_age_brackets``
        n (int)                    : number of people in the generated network (cast with ``int``)
        location (str)             : location name passed to ``sp.make_contacts``
        state_location (str)       : state used for both the network and the census age brackets
        country_location (str)     : country used for both the network and the census age brackets
        setting_code (str)         : contact setting whose matrix is computed and plotted
        aggregate_flag (bool)      : forwarded to ``plot_contact_matrix``
        logcolors_flag (bool)      : forwarded to ``plot_contact_matrix``
        density_or_frequency (str) : contact weighting, forwarded to ``calculate_contact_matrix``
                                     and ``plot_contact_matrix``
        trimmed_size_dic (dict)    : forwarded unchanged to ``sp.trim_contacts`` (None by default)

    Returns:
        The figure returned by ``plot_contact_matrix``.
    """
    popdict = {}

    options_args = {'use_microstructure': True}
    network_distr_args = {'Npop': int(n)}
    contacts = sp.make_contacts(popdict,
                                country_location=country_location,
                                state_location=state_location,
                                location=location,
                                options_args=options_args,
                                network_distr_args=network_distr_args)
    contacts = sp.trim_contacts(contacts,
                                trimmed_size_dic=trimmed_size_dic,
                                use_clusters=False)

    age_brackets = sp.get_census_age_brackets(
        datadir,
        state_location=state_location,
        country_location=country_location)
    age_by_brackets_dic = sp.get_age_by_brackets_dic(age_brackets)

    # Single-year age of every person in the trimmed network.
    ages = [contacts[uid]['age'] for uid in contacts]

    age_count = Counter(ages)
    aggregate_age_count = sp.get_aggregate_ages(age_count, age_by_brackets_dic)

    symmetric_matrix = calculate_contact_matrix(contacts, density_or_frequency,
                                                setting_code)

    fig = plot_contact_matrix(symmetric_matrix,
                              age_count,
                              aggregate_age_count,
                              age_brackets,
                              age_by_brackets_dic,
                              setting_code=setting_code,
                              density_or_frequency=density_or_frequency,
                              logcolors_flag=logcolors_flag,
                              aggregate_flag=aggregate_flag)
    return fig
コード例 #3
0
ファイル: test_plot_tools.py プロジェクト: sba5827/synthpops
def test_plot_generated_trimmed_contact_matrix(setting_code='H', n=5000, aggregate_flag=True, logcolors_flag=True,
                                               density_or_frequency='density', with_facilities=False, cmap='cmr.freeze_r', fontsize=16, rotation=50):
    """
    Plot the age mixing matrix for a specific setting of a generated population.

    Args:
        setting_code (str)               : name of the physical contact setting: H for households, S for schools, W for workplaces, C for community or other
        n (int)                          : number of people in the population
        aggregate_flag (bool)            : If True, plot the contact matrix for aggregate age brackets, else single year age contact matrix.
        logcolors_flag (bool)            : If True, plot heatmap in logscale
        density_or_frequency (str)       : If 'density', then each contact counts for 1/(group size -1) of a person's contact in a group, elif 'frequency' then count each contact. This means that more people in a group leads to higher rates of contact/exposure.
        with_facilities (bool)           : If True, create long term care facilities
        cmap (str or matplotlib colormap): colormap
        fontsize (int)                   : base font size
        rotation (int)                   : rotation for x axis labels

    Returns:
        A fig object.

    """
    datadir = sp.datadir

    state_location = 'Washington'
    country_location = 'usa'

    population = sp.make_population(n, generate=True, with_facilities=with_facilities)

    # Census age brackets used to aggregate single-year ages.
    age_brackets = sp.get_census_age_brackets(datadir, state_location=state_location, country_location=country_location)
    age_by_brackets_dic = sp.get_age_by_brackets_dic(age_brackets)

    # Single-year age of every person in the generated population.
    ages = [population[uid]['age'] for uid in population]

    age_count = Counter(ages)
    aggregate_age_count = sp.get_aggregate_ages(age_count, age_by_brackets_dic)

    matrix = sp.calculate_contact_matrix(population, density_or_frequency, setting_code)

    fig = sp.plot_contact_matrix(matrix, age_count, aggregate_age_count, age_brackets, age_by_brackets_dic,
                                 setting_code, density_or_frequency, logcolors_flag, aggregate_flag, cmap, fontsize, rotation)

    return fig
コード例 #4
0
def plot_contact_matrix_after_intervention(n,
                                           n_days,
                                           interventions,
                                           intervention_name,
                                           location='seattle_metro',
                                           state_location='Washington',
                                           country_location='usa',
                                           aggregate_flag=True,
                                           logcolors_flag=True,
                                           density_or_frequency='density',
                                           setting_code='H',
                                           cmap='cmr.freeze_r',
                                           fontsize=16,
                                           rotation=50):
    """
    Run a ``cv.Sim`` with the given interventions and plot one setting's age mixing matrix.

    Args:
        n (int)                          : population size (``pop_size``) of the simulation
        n_days (int)                     : number of days to simulate
        interventions                    : intervention(s) passed to ``cv.Sim(interventions=...)``
        intervention_name (str)          : not used by this function; kept for interface compatibility
        location (str)                   : not used by this function; kept for interface compatibility
        state_location (str)             : state used to look up the census age brackets
        country_location (str)           : country used to look up the census age brackets
        aggregate_flag (bool)            : If True, plot aggregate age brackets, else single year ages
        logcolors_flag (bool)            : If True, plot heatmap in logscale
        density_or_frequency (str)       : contact weighting, forwarded to ``calculate_contact_matrix``
        setting_code (str)               : contact setting whose matrix is computed and plotted
        cmap (str or matplotlib colormap): colormap
        fontsize (int)                   : base font size
        rotation (int)                   : rotation for x axis labels

    Returns:
        The figure returned by ``sp.plot_contact_matrix``.
    """
    pars = sc.objdict(pop_size=n, n_days=n_days, pop_type='synthpops')

    sim = cv.Sim(pars=pars, interventions=interventions)
    sim.run()

    age_brackets = sp.get_census_age_brackets(
        sp.datadir,
        state_location=state_location,
        country_location=country_location)
    age_by_brackets_dic = sp.get_age_by_brackets_dic(age_brackets)

    # NOTE(review): rounding to one decimal and then truncating with astype(int)
    # means an age within 0.05 below an integer is counted as that integer
    # (e.g. 4.96 -> 5.0 -> 5) while all other ages are truncated; confirm intent.
    ages = np.round(sim.people.age, 1).astype(int)
    max_age = ages.max()

    age_count = dict(Counter(ages))
    # Give every single-year age from 0 to max_age an entry, even when nobody has it.
    for age in range(max_age + 1):
        age_count.setdefault(age, 0)

    aggregate_age_count = sp.get_aggregate_ages(age_count, age_by_brackets_dic)

    matrix = calculate_contact_matrix(sim, density_or_frequency, setting_code)

    fig = sp.plot_contact_matrix(matrix, age_count, aggregate_age_count,
                                 age_brackets, age_by_brackets_dic,
                                 setting_code, density_or_frequency,
                                 logcolors_flag, aggregate_flag, cmap,
                                 fontsize, rotation)

    return fig
コード例 #5
0
def process_age_tables():
    """
    Preprocess the national age table into single-year and 16-bracket CSV files.

    Reads sheet ``A5`` of ``Series A. Population Tables.xlsx`` from the module-level
    ``dir_path``. It writes ``Malawi_national_ages.csv`` (one row per single-year age)
    and ``Malawi_national_ages_16.csv`` (the same distribution aggregated into the
    16 census age brackets for seattle-metro, Washington, usa). The aggregated table
    is also printed.
    """
    file_path = os.path.join(dir_path, 'Series A. Population Tables.xlsx')
    df = pd.read_excel(file_path,
                       sheet_name='A5',
                       header=1,
                       skiprows=[2, 3],
                       skipfooter=303)

    # The first data row is dropped; presumably a total/label row — TODO confirm.
    ages = df['Age in single Years'].values[1:]
    age_count = np.array(df['National'].values[1:])
    # Assumes the rows are consecutive single-year ages starting at 0.
    age_range = np.arange(len(ages))

    age_dist = age_count / age_count.sum()
    age_dist_mapping = dict(zip(age_range, age_dist))

    data = dict(age_min=sc.dcp(age_range),
                age_max=sc.dcp(age_range),
                age_dist=age_dist)
    # Cap the last row's upper bound at 100 (presumably an open-ended top age group).
    data['age_max'][-1] = 100
    new_df = pd.DataFrame.from_dict(data)

    new_file_path = os.path.join(dir_path, 'Malawi_national_ages.csv')
    new_df.to_csv(new_file_path, index=False)

    census_age_brackets = sp.get_census_age_brackets(
        sp.settings.datadir,
        location='seattle-metro',
        state_location='Washington',
        country_location='usa',
        nbrackets=16)
    census_age_by_brackets = sp.get_age_by_brackets(census_age_brackets)

    agg_ages = sp.get_aggregate_ages(age_dist_mapping, census_age_by_brackets)

    # Use one sorted bracket order for every column so that each row's
    # age_min, age_max and age_dist describe the same bracket.
    bracket_keys = sorted(census_age_brackets.keys())
    agg_data = dict()
    agg_data['age_min'] = np.array(
        [census_age_brackets[b][0] for b in bracket_keys])
    agg_data['age_max'] = np.array(
        [census_age_brackets[b][-1] for b in bracket_keys])
    agg_data['age_dist'] = np.array(
        [agg_ages[b] for b in bracket_keys])
    agg_df = pd.DataFrame.from_dict(agg_data)
    print(agg_df)
    agg_path = os.path.join(dir_path, 'Malawi_national_ages_16.csv')
    agg_df.to_csv(agg_path, index=False)
コード例 #6
0
# Script excerpt: build a population, tally its ages, and set contact-matrix
# plot options. n, with_facilities, state_location and country_location are
# defined outside this excerpt; the script presumably continues past
# vbounds = {} (TODO confirm).
generate = True
population = sp.make_population(n,
                                generate=generate,
                                with_facilities=with_facilities)

# aggregate age brackets
age_brackets = sp.get_census_age_brackets(sp.datadir,
                                          state_location=state_location,
                                          country_location=country_location)
age_by_brackets_dic = sp.get_age_by_brackets_dic(age_brackets)

# Count people per single-year age, then sum those counts into the census
# age brackets.
ages = []
for uid in population:
    ages.append(population[uid]['age'])
age_count = Counter(ages)
aggregate_age_count = sp.get_aggregate_ages(age_count, age_by_brackets_dic)

# Plot options; commented-out lines are alternative settings kept for quick toggling.
# matrix_dic is not filled in within this excerpt.
matrix_dic = {}
density_or_frequency = 'density'
# density_or_frequency = 'frequency'

aggregate_flag = True
# aggregate_flag = False

# NOTE(review): the second assignment wins, so logcolors_flag ends up False;
# the first line is dead and was probably meant to be commented out like the
# other toggles above.
logcolors_flag = True
logcolors_flag = False
# log color bounds

matrix_cmap = 'cmr.freeze_r'

vbounds = {}