def rebin_matrix_by_age(matrix, datadir, location="seattle_metro", state_location="Washington", country_location="usa"):
    """
    Average a single-year-of-age contact matrix over census age brackets.

    The matrix is first summed into age brackets with sp.get_aggregate_matrix.
    Each bracket cell (i, j) is then divided by (size of bracket i) *
    (size of bracket j), giving the mean value per original cell. Here a
    bracket's size is the number of entries of the age-to-bracket mapping
    that fall in it.

    @TODO: consider merging with sp.get_aggregate_matrix, or removing this
    helper, since averaging in this way may not be scientifically meaningful.

    Args:
        matrix: contact matrix indexed by single year of age.
        datadir (str): data directory used to look up the census age brackets.
        location (str): location used to look up the census age brackets.
        state_location (str): state location used to look up the census age brackets.
        country_location (str): country location used to look up the census age brackets.

    Returns:
        numpy.ndarray: bracket-by-bracket matrix holding the averaged values.
    """
    age_brackets = sp.get_census_age_brackets(datadir, location, state_location, country_location)
    bracket_of_age = sp.get_age_by_brackets_dic(age_brackets)
    binned = sp.get_aggregate_matrix(matrix, bracket_of_age)

    # Number of mapping entries per bracket. Bracket indices are used directly
    # as row/column positions, so they are assumed to run 0..k-1.
    bracket_sizes = Counter(bracket_of_age.values())
    n_brackets = len(bracket_sizes)
    for row in range(n_brackets):
        size_row = bracket_sizes[row]
        for col in range(n_brackets):
            binned[row, col] /= (size_row * bracket_sizes[col])
    return binned
def get_pop_details(self, pop, dir, title_prefix, location, state_location, country_location, decimal=3):
    """
    Write per-setting contact summaries and plots for a population.

    For each setting code 'H', 'W' and 'S' this:
      * plots the average number of contacts by age (into self.figDir) and
        saves it as JSON in `dir`;
      * for both 'density' and 'frequency' contact matrices, saves the matrix
        aggregated into census age brackets as CSV in `dir`, and saves the
        generated contact-matrix plot as PNG in self.figDir.

    Args:
        pop: population object passed to the synthpops / utilities helpers.
        dir (str): output directory for the JSON and CSV files; created if missing.
        title_prefix: not used by this method; plot titles are derived from
            self.generateBaseline instead. Kept for interface compatibility.
        location (str): location used to look up the census age brackets.
        state_location (str): state location used to look up the census age brackets.
        country_location (str): country location used to look up the census age brackets.
        decimal (int): number of decimal places for the averages and the CSV output.

    Returns:
        None
    """
    # Imported locally so figures can be released after saving; the module
    # already plots through matplotlib.
    import matplotlib.pyplot as plt

    os.makedirs(dir, exist_ok=True)
    fmt = f'%.{decimal}f'
    label_prefix = " Expected " if self.generateBaseline else " Test "

    # The census age brackets depend only on the location arguments, so look
    # them up once. All three location levels are passed in the helper's
    # positional order (datadir, location, state_location, country_location).
    # Omitting `location` would shift state/country into the wrong parameters.
    brackets = sp.get_census_age_brackets(self.datadir, location, state_location, country_location)
    age_by_brackets = sp.get_age_by_brackets_dic(brackets)

    for setting_code in ['H', 'W', 'S']:
        file_stem = f"{self.n}_seed_{self.seed}_{setting_code}"

        average_contacts = utilities.get_average_contact_by_age(
            pop, self.datadir, setting_code=setting_code, decimal=decimal)
        utilities.plot_array(
            average_contacts,
            datadir=self.figDir,
            testprefix=f"{file_stem}_average_contacts",
            expect_label='Expected' if self.generateBaseline else 'Test')
        sc.savejson(
            os.path.join(dir, f"{file_stem}_average_contact.json"),
            dict(enumerate(average_contacts.tolist())),
            indent=2)

        for matrix_type in ['density', 'frequency']:
            matrix = sp.calculate_contact_matrix(pop, matrix_type, setting_code)
            agg_matrix = sp.get_aggregate_matrix(matrix, age_by_brackets)
            np.savetxt(
                os.path.join(dir, f"{file_stem}_{matrix_type}_contact_matrix.csv"),
                agg_matrix, delimiter=",", fmt=fmt)

            fig = plot_age_mixing_matrices.test_plot_generated_contact_matrix(
                setting_code=setting_code,
                population=pop,
                title_prefix=label_prefix,
                density_or_frequency=matrix_type)
            fig.savefig(os.path.join(self.figDir, f"{file_stem}_{matrix_type}_contact_matrix.png"))
            # Release the figure; otherwise every call leaves six open figures behind.
            plt.close(fig)
def plot_contact_matrix(matrix, age_count, aggregate_age_count, age_brackets, age_by_brackets_dic,
                        setting_code='H', density_or_frequency='density', logcolors_flag=False,
                        aggregate_flag=True):
    """
    Plot an age-mixing contact matrix for a single setting.

    The matrix is optionally aggregated into age brackets and then converted
    with sp.get_asymmetric_matrix. It is drawn transposed, with the origin at
    the lower left, next to a colorbar.

    Args:
        matrix: contact matrix by single year of age.
        age_count: population count by single year of age; used when
            aggregate_flag is False.
        aggregate_age_count: population count by age bracket; used when
            aggregate_flag is True.
        age_brackets: mapping of bracket index -> sequence of ages in that
            bracket (iteration order assumed to follow bracket index); used for
            the tick labels.
        age_by_brackets_dic: mapping of age -> bracket index.
        setting_code (str): 'H' (household), 'S' (school) or 'W' (work).
            Other codes are titled with the raw code but have no log-scale bounds.
        density_or_frequency (str): 'density' or 'frequency'. Selects the
            colorbar label and, together with logcolors_flag, the color bounds.
        logcolors_flag (bool): if True, use a logarithmic color scale with
            fixed bounds.
        aggregate_flag (bool): if True, aggregate into age brackets and label
            the axes with the bracket ranges.

    Returns:
        matplotlib.figure.Figure: the figure containing the plot.

    Raises:
        KeyError: if logcolors_flag is True and no bounds are defined for the
            given density_or_frequency / aggregate_flag / setting_code combination.
    """
    # cmocean colormaps are already Colormap objects. matplotlib.cm.get_cmap
    # returned them unchanged, and that function was removed in Matplotlib 3.9,
    # so use the colormap directly.
    cmap = cmocean.cm.matter_r

    if aggregate_flag:
        aggregate_M = sp.get_aggregate_matrix(matrix, age_by_brackets_dic)
        asymmetric_M = sp.get_asymmetric_matrix(aggregate_M, aggregate_age_count)
    else:
        asymmetric_M = sp.get_asymmetric_matrix(matrix, age_count)

    imshow_kwargs = {'origin': 'lower', 'interpolation': 'nearest', 'cmap': cmap}
    if logcolors_flag:
        # Fixed (vmin, vmax) log color bounds, keyed by
        # (density_or_frequency, aggregated?) and then by setting code.
        log_bounds = {
            ('density', True): {'H': (1e-2, 1e0), 'S': (1e-3, 1e1), 'W': (1e-3, 1e1)},
            ('density', False): {'H': (1e-3, 1e-1), 'S': (1e-3, 1e-1), 'W': (1e-3, 1e-1)},
            ('frequency', True): {'H': (1e-2, 1e0), 'S': (1e-3, 1e1), 'W': (1e-2, 1e0)},
            ('frequency', False): {'H': (1e-2, 1e0), 'S': (1e-2, 1e0), 'W': (1e-2, 1e0)},
        }
        vmin, vmax = log_bounds[(density_or_frequency, bool(aggregate_flag))][setting_code]
        imshow_kwargs['norm'] = LogNorm(vmin=vmin, vmax=vmax)

    # Create the figure only after every lookup that can fail, so an error
    # does not leave an orphaned pyplot figure behind.
    fig = plt.figure(figsize=(7, 7), tight_layout=True)
    ax = fig.add_subplot(111)
    im = ax.imshow(asymmetric_M.T, **imshow_kwargs)

    # Colorbar in its own axis, appended to the right of the matrix.
    divider = make_axes_locatable(ax)
    cax = divider.new_horizontal(size="3.5%", pad=0.1)
    fig.add_axes(cax)
    cbar = fig.colorbar(im, cax=cax)
    cbar.ax.tick_params(axis='y', labelsize=20)
    if density_or_frequency == 'frequency':
        cbar_label = 'Frequency of Contacts'
    else:
        cbar_label = 'Density of Contacts'
    cbar.ax.set_ylabel(cbar_label, fontsize=20)

    titles = {'H': 'Household', 'S': 'School', 'W': 'Work'}
    ax.tick_params(labelsize=20)
    ax.set_xlabel('Age', fontsize=22)
    ax.set_ylabel('Age of Contacts', fontsize=22)
    ax.set_title(f"{titles.get(setting_code, setting_code)} Contact Patterns", fontsize=28)

    if aggregate_flag:
        # Label each bracket by its first and last age, e.g. "0-4".
        tick_labels = [f"{age_brackets[b][0]}-{age_brackets[b][-1]}" for b in age_brackets]
        tick_positions = np.arange(len(tick_labels))
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, fontsize=18, rotation=50)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels, fontsize=18)

    return fig
im = [] cax = [] cbar = [] rotation = 66 for l, layer in enumerate(keys_to_plot): setting_code = layer.title() if setting_code == 'L': setting_code = 'LTCF' print(f'Plotting average age mixing contact matrix in layer: {layer}') matrix_dic[layer] = sp.calculate_contact_matrix(population, density_or_frequency, setting_code) if aggregate_flag: aggregate_matrix = sp.get_aggregate_matrix(matrix_dic[layer], age_by_brackets_dic) asymmetric_matrix = sp.get_asymmetric_matrix(aggregate_matrix, aggregate_age_count) else: asymmetric_matrix = sp.get_asymmetric_matrix(matrix_dic[layer], age_count) im.append(ax[2 * l].imshow(asymmetric_matrix.T, origin='lower', interpolation='nearest', cmap=matrix_cmap, norm=LogNorm( vmin=vbounds[setting_code]['vmin'], vmax=vbounds[setting_code]['vmax'])))
n_contacts = (sim.people.contacts[layer]['p1'] == p).sum() contact_ages = ages[contacts] for ca in contact_ages: symmetric_matrix[ages[p]][ca] += 1 n_contacts_count[ages[p]] += n_contacts age_brackets = sp.get_census_age_brackets(sp.datadir, state_location='Washington', country_location='usa') age_by_brackets_dic = sp.get_age_by_brackets_dic(age_brackets) aggregate_age_count = sp.get_aggregate_ages(age_count, age_by_brackets_dic) aggregate_matrix = symmetric_matrix.copy() aggregate_matrix = sp.get_aggregate_matrix(aggregate_matrix, age_by_brackets_dic) asymmetric_matrix = sp.get_asymmetric_matrix(aggregate_matrix, aggregate_age_count) fig = plt.figure(figsize=(8, 8)) ax = fig.add_subplot(111) im = ax.imshow(asymmetric_matrix.T, origin='lower', interpolation='nearest', cmap=cmap, norm=LogNorm(vmin=1e-1, vmax=1e1)) # im = ax.imshow(asymmetric_matrix.T, origin='lower', interpolation='nearest', cmap=cmap, ) divider = make_axes_locatable(ax) cax = divider.new_horizontal(size="4%", pad=0.15)