示例#1
0
def rebin_matrix_by_age(matrix,
                        datadir,
                        location="seattle_metro",
                        state_location="Washington",
                        country_location="usa"):
    """
    Average a single-age contact matrix over census age brackets.

    The matrix is first aggregated into bracket x bracket cells with
    ``sp.get_aggregate_matrix`` (presumably a sum over the ages in each
    bracket pair -- confirm against synthpops). Each cell ``(i, j)`` is then
    divided by ``n_i * n_j``, where ``n_i`` is the number of single ages
    mapped to bracket ``i``.

    @TODO: should we merge the functionalities with sp.get_aggregate_matrix
    or remove as this operation may not be scientifically meaningful (?)

    Args:
        matrix           : raw matrix with single age bracket
        datadir          : data directory
        location         : location used to look up the census age brackets
        state_location   : state location
        country_location : country location

    Returns:
        numpy.ndarray: A matrix with desired age bracket with average values for all cells.

    """
    import numpy as np  # local import keeps this helper self-contained

    brackets = sp.get_census_age_brackets(datadir, location, state_location,
                                          country_location)
    ageindex = sp.get_age_by_brackets_dic(brackets)
    agg_matrix = sp.get_aggregate_matrix(matrix, ageindex)

    # Number of single ages mapped to each bracket. Bracket indices are
    # assumed to be 0..K-1, the same assumption the aggregated matrix's
    # row/column indexing relies on.
    counter = Counter(ageindex.values())
    counts = np.array([counter[b] for b in range(len(counter))], dtype=float)

    # Divide every cell (i, j) by n_i * n_j in a single vectorised step.
    return agg_matrix / np.outer(counts, counts)
示例#2
0
    def get_pop_details(self,
                        pop,
                        dir,
                        title_prefix,
                        location,
                        state_location,
                        country_location,
                        decimal=3):
        """
        Write per-setting contact summaries of ``pop`` to disk.

        For each setting code 'H', 'W' and 'S' this saves:
          * the average contacts by age, as a plot in ``self.figDir`` and as
            JSON in ``dir``;
          * the aggregated 'density' and 'frequency' contact matrices, as CSV
            in ``dir`` and as PNG plots in ``self.figDir``.

        Args:
            pop              : population to summarize
            dir              : output directory for JSON/CSV files; created if missing
            title_prefix     : currently unused; plot titles/labels are chosen
                               from ``self.generateBaseline`` instead
            location         : location used to look up census age brackets
            state_location   : state location used to look up census age brackets
            country_location : country location used to look up census age brackets
            decimal          : decimals kept in the averages and in the CSV output
        """
        os.makedirs(dir, exist_ok=True)
        fmt = f'%.{decimal}f'
        label = 'Expected' if self.generateBaseline else 'Test'
        file_prefix = f"{self.n}_seed_{self.seed}"

        # The census brackets depend only on the location, so look them up
        # once rather than once per (setting, matrix type) pair. Pass keyword
        # arguments: positionally, state_location would land in the
        # ``location`` slot of get_census_age_brackets.
        brackets = sp.get_census_age_brackets(self.datadir,
                                              location=location,
                                              state_location=state_location,
                                              country_location=country_location)
        ageindex = sp.get_age_by_brackets_dic(brackets)

        for setting_code in ['H', 'W', 'S']:
            setting_prefix = f"{file_prefix}_{setting_code}"

            average_contacts = utilities.get_average_contact_by_age(
                pop, self.datadir, setting_code=setting_code, decimal=decimal)
            utilities.plot_array(
                average_contacts,
                datadir=self.figDir,
                testprefix=f"{setting_prefix}_average_contacts",
                expect_label=label)
            sc.savejson(os.path.join(dir, f"{setting_prefix}_average_contact.json"),
                        dict(enumerate(average_contacts.tolist())),
                        indent=2)

            for density_or_frequency in ['density', 'frequency']:
                matrix = sp.calculate_contact_matrix(pop, density_or_frequency,
                                                     setting_code)
                agg_matrix = sp.get_aggregate_matrix(matrix, ageindex)
                matrix_name = f"{setting_prefix}_{density_or_frequency}_contact_matrix"
                np.savetxt(os.path.join(dir, f"{matrix_name}.csv"),
                           agg_matrix,
                           delimiter=",",
                           fmt=fmt)

                fig = plot_age_mixing_matrices.test_plot_generated_contact_matrix(
                    setting_code=setting_code,
                    population=pop,
                    title_prefix=f" {label} ",
                    density_or_frequency=density_or_frequency)
                # NOTE(review): the figure is never closed after saving; with
                # many calls matplotlib keeps all of them alive -- consider
                # closing it once saved.
                fig.savefig(os.path.join(self.figDir, f"{matrix_name}.png"))
示例#3
0
def plot_contact_matrix(matrix,
                        age_count,
                        aggregate_age_count,
                        age_brackets,
                        age_by_brackets_dic,
                        setting_code='H',
                        density_or_frequency='density',
                        logcolors_flag=False,
                        aggregate_flag=True):
    """
    Plot the (asymmetric) age-mixing contact matrix for one setting.

    Args:
        matrix               : contact matrix (presumably by single year of age -- confirm)
        age_count            : age counts passed to sp.get_asymmetric_matrix with the raw matrix
        aggregate_age_count  : bracket counts passed to sp.get_asymmetric_matrix with the
                               aggregated matrix
        age_brackets         : mapping of bracket key -> sequence of ages; first and last
                               entries are used for tick labels
        age_by_brackets_dic  : mapping passed to sp.get_aggregate_matrix to bin the matrix
        setting_code         : 'H', 'S' or 'W'; other codes are titled with the raw code
        density_or_frequency : 'density' or 'frequency'; picks the colorbar label and,
                               with logcolors_flag, the color bounds
        logcolors_flag       : if True, use a log color scale with preset bounds
        aggregate_flag       : if True, aggregate into age brackets first and label ticks
                               with the bracket ranges

    Returns:
        matplotlib.figure.Figure: the created figure.

    Raises:
        KeyError: if logcolors_flag is set and there are no preset bounds for
            (density_or_frequency, setting_code).
    """
    # cmocean colormaps are already Colormap objects, so no registry lookup is
    # needed (matplotlib.cm.get_cmap is removed in Matplotlib 3.9).
    # Alternative palette: cmocean.cm.deep_r
    cmap = cmocean.cm.matter_r

    # (density_or_frequency, aggregated?) -> setting_code -> (vmin, vmax)
    log_color_bounds = {
        ('density', True): {'H': (1e-2, 1e-0), 'S': (1e-3, 1e1), 'W': (1e-3, 1e1)},
        ('density', False): {'H': (1e-3, 1e-1), 'S': (1e-3, 1e-1), 'W': (1e-3, 1e-1)},
        ('frequency', True): {'H': (1e-2, 1e0), 'S': (1e-3, 1e1), 'W': (1e-2, 1e0)},
        ('frequency', False): {'H': (1e-2, 1e0), 'S': (1e-2, 1e0), 'W': (1e-2, 1e0)},
    }
    titles = {'H': 'Household', 'S': 'School', 'W': 'Work'}

    fig = plt.figure(figsize=(7, 7), tight_layout=True)
    ax = fig.add_subplot(111)

    if aggregate_flag:
        aggregate_M = sp.get_aggregate_matrix(matrix, age_by_brackets_dic)
        asymmetric_M = sp.get_asymmetric_matrix(aggregate_M,
                                                aggregate_age_count)
    else:
        asymmetric_M = sp.get_asymmetric_matrix(matrix, age_count)

    imshow_kwargs = dict(origin='lower', interpolation='nearest', cmap=cmap)
    if logcolors_flag:
        bounds_by_setting = log_color_bounds.get(
            (density_or_frequency, bool(aggregate_flag)), {})
        vmin, vmax = bounds_by_setting[setting_code]
        imshow_kwargs['norm'] = LogNorm(vmin=vmin, vmax=vmax)
    # Transposed so the x-axis is the person's age and the y-axis the contact's age.
    im = ax.imshow(asymmetric_M.T, **imshow_kwargs)

    divider = make_axes_locatable(ax)
    cax = divider.new_horizontal(size="3.5%", pad=0.1)
    fig.add_axes(cax)
    cbar = fig.colorbar(im, cax=cax)
    cbar.ax.tick_params(axis='y', labelsize=20)
    if density_or_frequency == 'frequency':
        cbar_label = 'Frequency of Contacts'
    else:
        cbar_label = 'Density of Contacts'
    cbar.ax.set_ylabel(cbar_label, fontsize=20)

    ax.tick_params(labelsize=20)
    ax.set_xlabel('Age', fontsize=22)
    ax.set_ylabel('Age of Contacts', fontsize=22)
    ax.set_title(titles.get(setting_code, setting_code) + ' Contact Patterns',
                 fontsize=28)

    if aggregate_flag:
        tick_labels = [
            str(age_brackets[b][0]) + '-' + str(age_brackets[b][-1])
            for b in age_brackets
        ]
        ticks = np.arange(len(tick_labels))
        ax.set_xticks(ticks)
        ax.set_xticklabels(tick_labels, fontsize=18, rotation=50)
        ax.set_yticks(ticks)
        ax.set_yticklabels(tick_labels, fontsize=18)

    return fig
示例#4
0
# Per-layer plot handles. cax, cbar and rotation are not used inside this
# visible fragment -- presumably consumed by colorbar/tick code that follows.
im = []
cax = []
cbar = []
rotation = 66

# Plot one age-mixing matrix per contact layer.
# NOTE(review): assumes keys_to_plot, matrix_dic, population,
# density_or_frequency, aggregate_flag, age_by_brackets_dic,
# aggregate_age_count, age_count, ax, matrix_cmap and vbounds are defined
# earlier in the script (not visible here) -- confirm.
for l, layer in enumerate(keys_to_plot):
    # Title-case the layer key into a setting code (e.g. 'h' -> 'H');
    # the 'L' layer is special-cased to the 'LTCF' setting code.
    setting_code = layer.title()
    if setting_code == 'L':
        setting_code = 'LTCF'
    print(f'Plotting average age mixing contact matrix in layer: {layer}')
    matrix_dic[layer] = sp.calculate_contact_matrix(population,
                                                    density_or_frequency,
                                                    setting_code)

    # Optionally bin into age brackets before converting to the asymmetric
    # form; the age counts passed must match the matrix's binning.
    if aggregate_flag:
        aggregate_matrix = sp.get_aggregate_matrix(matrix_dic[layer],
                                                   age_by_brackets_dic)
        asymmetric_matrix = sp.get_asymmetric_matrix(aggregate_matrix,
                                                     aggregate_age_count)

    else:
        asymmetric_matrix = sp.get_asymmetric_matrix(matrix_dic[layer],
                                                     age_count)

    # Images go on even-indexed axes (2 * l); presumably the odd axes are
    # reserved for colorbars -- TODO confirm. Log color bounds come from
    # vbounds keyed by setting code (so vbounds must include 'LTCF' if that
    # layer is plotted).
    im.append(ax[2 * l].imshow(asymmetric_matrix.T,
                               origin='lower',
                               interpolation='nearest',
                               cmap=matrix_cmap,
                               norm=LogNorm(
                                   vmin=vbounds[setting_code]['vmin'],
                                   vmax=vbounds[setting_code]['vmax'])))
示例#5
0
                n_contacts = (sim.people.contacts[layer]['p1'] == p).sum()
                contact_ages = ages[contacts]

                for ca in contact_ages:
                    symmetric_matrix[ages[p]][ca] += 1
                n_contacts_count[ages[p]] += n_contacts

        age_brackets = sp.get_census_age_brackets(sp.datadir,
                                                  state_location='Washington',
                                                  country_location='usa')
        age_by_brackets_dic = sp.get_age_by_brackets_dic(age_brackets)

        aggregate_age_count = sp.get_aggregate_ages(age_count,
                                                    age_by_brackets_dic)
        aggregate_matrix = symmetric_matrix.copy()
        aggregate_matrix = sp.get_aggregate_matrix(aggregate_matrix,
                                                   age_by_brackets_dic)

        asymmetric_matrix = sp.get_asymmetric_matrix(aggregate_matrix,
                                                     aggregate_age_count)

        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(111)
        im = ax.imshow(asymmetric_matrix.T,
                       origin='lower',
                       interpolation='nearest',
                       cmap=cmap,
                       norm=LogNorm(vmin=1e-1, vmax=1e1))
        # im = ax.imshow(asymmetric_matrix.T, origin='lower', interpolation='nearest', cmap=cmap, )

        divider = make_axes_locatable(ax)
        cax = divider.new_horizontal(size="4%", pad=0.15)