def on_collect_metrics(self, event):
    """Tally simulants per risk category within every age/sex group.

    The population for ``event.index`` is always fetched, but counts are
    only produced when ``self.should_sample(event.time)`` is truthy. The
    resulting ``{output key: count}`` mapping is merged into
    ``self.category_counts``.
    """
    pop = self.population_view.get(event.index)
    if not self.should_sample(event.time):
        return

    age_sex_filter, (ages, sexes) = get_age_sex_filter_and_iterables(
        self.config, self.age_bins)
    counts_by_key = {}
    exposure = self.exposure(pop.index)

    for group, age_group in ages:
        start = age_group.age_group_start
        end = age_group.age_group_end
        for sex in sexes:
            key_fields = dict(age_group_start=start,
                              age_group_end=end,
                              sex=sex,
                              age_group=group)
            query = age_sex_filter.format(**key_fields)
            # With an empty query string or an empty population there is
            # nothing to filter, so the population is used as-is.
            if query and not pop.empty:
                members = pop.query(query)
            else:
                members = pop

            for cat in self.categories:
                template = get_output_template(**self.config)
                measure_key = template.substitute(
                    measure=f'{self.risk.name}_{cat}_exposed',
                    year=self.clock().year)
                full_key = measure_key.substitute(**key_fields)
                in_category = exposure.loc[members.index] == cat
                counts_by_key[full_key] = in_category.sum()

    self.category_counts.update(counts_by_key)
def on_collect_metrics(self, event):
    """Count simulants by anemia level within every age/sex group.

    Only simulants matching ``alive == "alive"`` are fetched. When
    ``self.should_sample(event.time)`` is truthy, one count is written per
    level (mild, moderate, severe) plus an ``unexposed`` count equal to the
    group size minus the three level counts. Results are merged into
    ``self.anemia_counts``.
    """
    pop = self.pop_view.get(event.index, query='alive == "alive"')
    if not self.should_sample(event.time):
        return

    age_sex_filter, (ages, sexes) = get_age_sex_filter_and_iterables(
        self.observer_config, self.age_bins)
    counts_by_key = {}

    for group, age_group in ages:
        start = age_group.age_group_start
        end = age_group.age_group_end
        for sex in sexes:
            key_fields = dict(age_group_start=start,
                              age_group_end=end,
                              sex=sex,
                              age_group=group)
            query = age_sex_filter.format(**key_fields)
            # Skip filtering when there is no query or nobody to filter.
            if query and not pop.empty:
                members = pop.query(query)
            else:
                members = pop

            # NOTE(review): split_for_anemia presumably returns one
            # selector per level, in mild/moderate/severe order -- confirm.
            selectors = self.split_for_anemia(members.index)
            n_exposed = 0
            for level, selector in zip(['mild', 'moderate', 'severe'],
                                       selectors):
                template = get_output_template(**self.observer_config)
                level_key = template.substitute(
                    measure=f'{level}_anemia_counts',
                    year=self.clock().year)
                full_key = level_key.substitute(**key_fields)
                n_level = len(members[selector])
                counts_by_key[full_key] = n_level
                n_exposed += n_level

            template = get_output_template(**self.observer_config)
            unexposed_key = template.substitute(
                measure='unexposed_anemia_counts',
                year=self.clock().year)
            full_key = unexposed_key.substitute(**key_fields)
            counts_by_key[full_key] = len(members) - n_exposed

    self.anemia_counts.update(counts_by_key)
def test_output_template(observer_config, measure, sex, age, year):
    """Substituting every field at once must equal substituting one by one."""
    template = get_output_template(**observer_config)
    fields = {'measure': measure, 'sex': sex, 'age_group': age, 'year': year}

    all_at_once = template.substitute(**fields)

    # Apply the same fields sequentially, in the same order.
    one_by_one = template
    for name, value in fields.items():
        one_by_one = one_by_one.substitute(**{name: value})

    assert all_at_once == one_by_one
def test_get_output_template(observer_config):
    """The template has a measure field plus one fragment per enabled flag."""
    template = get_output_template(**observer_config)

    assert isinstance(template, OutputTemplate)
    assert '${measure}' in template.template

    # Each stratification flag, when enabled, contributes a fragment.
    fragment_for_flag = (
        ('by_year', '_in_${year}'),
        ('by_sex', '_among_${sex}'),
        ('by_age', '_in_age_group_${age_group}'),
    )
    for flag, fragment in fragment_for_flag:
        if observer_config[flag]:
            assert fragment in template.template
def get_person_time(pop: pd.DataFrame, config: Dict[str, bool],
                    current_year: Union[str, int], step_size: pd.Timedelta,
                    age_bins: pd.DataFrame) -> Dict[str, float]:
    """Return grouped person time for simulants with ``alive == "alive"``.

    Each group's value is its size multiplied by ``to_years(step_size)``.
    Keys come from the output template with ``measure='person_time'`` and
    ``year=current_year``; grouping is delegated to ``get_group_counts``.
    """
    def group_person_time(group):
        return len(group) * to_years(step_size)

    template = get_output_template(**config)
    base_key = template.substitute(measure='person_time', year=current_year)
    alive_filter = QueryString('alive == "alive"')
    return get_group_counts(pop, alive_filter, base_key, config, age_bins,
                            aggregate=group_person_time)
def _get_birth_weight_sum(self, pop: pd.DataFrame, base_filter: QueryString,
                          configuration: Dict,
                          time_spans: List[Tuple[str, Tuple[pd.Timestamp, pd.Timestamp]]],
                          age_bins: pd.DataFrame) -> Dict[str, float]:
    """Sum the birth-weight column per group for every time span.

    For each ``(year, (start, end))`` entry of ``time_spans`` the base
    filter is formatted with ``start_time``/``end_time`` and the key with
    ``year``. ``utilities.get_group_counts`` then aggregates each group by
    summing the column named ``self.birth_weight_pipeline_name``.
    """
    def total_birth_weight(group):
        return group[self.birth_weight_pipeline_name].sum()

    template = utilities.get_output_template(**configuration)
    measure_key = template.substitute(measure='birth_weight_sum')

    sums_by_key = {}
    for year, (span_start, span_end) in time_spans:
        span_filter = base_filter.format(start_time=span_start,
                                         end_time=span_end)
        span_key = measure_key.substitute(year=year)
        span_sums = utilities.get_group_counts(
            pop, span_filter, span_key, configuration, age_bins,
            total_birth_weight)
        sums_by_key.update(span_sums)
    return sums_by_key
def get_transition_count(pop: pd.DataFrame, config: Dict[str, bool],
                         disease: str,
                         transition: project_globals.TransitionString,
                         event_time: pd.Timestamp,
                         age_bins: pd.DataFrame) -> Dict[str, float]:
    """Count simulants that made ``transition`` this step, by group.

    A simulant is counted when its ``previous_<disease>`` column equals
    ``transition.from_state`` and its ``<disease>`` column equals
    ``transition.to_state``. Keys use the measure
    ``<transition>_event_count`` and the year of ``event_time``.
    """
    was_in_from_state = pop[f'previous_{disease}'] == transition.from_state
    is_in_to_state = pop[disease] == transition.to_state
    movers = pop.loc[was_in_from_state & is_in_to_state]

    template = get_output_template(**config)
    base_key = template.substitute(measure=f'{transition}_event_count',
                                   year=event_time.year)
    # The population is already restricted, so no extra filter is needed.
    no_filter = QueryString('')
    return get_group_counts(movers, no_filter, base_key, config, age_bins)
def get_births(pop: pd.DataFrame, config: Dict[str, bool],
               sim_start: pd.Timestamp,
               sim_end: pd.Timestamp) -> Dict[str, int]:
    """Counts the number of births and births with neural tube defects prevalent.

    Parameters
    ----------
    pop
        The population dataframe to be counted. It must contain sufficient
        columns for any necessary filtering (e.g. the ``age`` column if
        filtering by age).
    config
        A dict with ``by_age``, ``by_sex``, and ``by_year`` keys and
        boolean values.
    sim_start
        The simulation start time.
    sim_end
        The simulation end time.

    Returns
    -------
    births
        All births and births with neural tube defects present.
    """
    no_filter = QueryString('')
    template = get_output_template(**config)
    time_spans = get_time_iterable(config, sim_start, sim_end)

    births = {}
    for year, (t_start, t_end) in time_spans:
        # Clip each span to the simulation window.
        start = max(sim_start, t_start)
        end = min(sim_end, t_end)
        newborns = pop.query(
            f'"{start}" <= entrance_time and entrance_time < "{end}"')

        live_key = template.substitute(measure='live_births', year=year)
        # Age bins are passed as an empty frame for both birth measures.
        live_births = get_group_counts(newborns, no_filter, live_key, config,
                                       pd.DataFrame())
        births.update(live_births)

        ntd_key = template.substitute(measure='born_with_ntds', year=year)
        ntd_clause = f'{project_globals.NTD_MODEL_NAME} == "{project_globals.NTD_MODEL_NAME}"'
        no_age_bins = pd.DataFrame()
        ntd_births = get_group_counts(newborns, no_filter + ntd_clause,
                                      ntd_key, config, no_age_bins)
        births.update(ntd_births)

    return births
def test_output_template_exact():
    """A fully stratified template renders the exact expected keys."""
    template = get_output_template(by_age=True, by_sex=True, by_year=True)

    # (age_group value, expected rendered key)
    cases = (
        (1.0, 'test_in_2011_among_female_in_age_group_1.0'),
        ('Early Neonatal',
         'test_in_2011_among_female_in_age_group_early_neonatal'),
    )
    for age_group, expected in cases:
        rendered = template.substitute(measure='Test', sex='Female',
                                       age_group=age_group, year=2011)
        assert rendered == expected
def on_collect_metrics(self, event):
    """Summarize medication use for simulants prescribed during this step.

    Living simulants are kept when their ``last_prescription_date`` lies
    after ``self.clock()`` and on or before ``event.time``. For every drug
    in ``HYPERTENSION_DRUGS`` with at least one simulant on a positive
    dosage, ``self.summarize_drug_by_group`` is called and its result is
    merged into ``self.counts``.
    """
    living = self.population_view.get(event.index).query('alive == "alive"')
    prescribed_this_step = (
        (self.clock() < living.last_prescription_date)
        & (living.last_prescription_date <= event.time))
    # Take an explicit copy: adding a column to the result of a boolean
    # selection is chained assignment and raises SettingWithCopyWarning.
    pop = living.loc[prescribed_this_step].copy()
    pop['num_in_single_pill'] = pop[SINGLE_PILL_COLUMNS].sum(axis=1)

    base_key = get_output_template(**self.config).substitute(
        year=event.time.year)

    med_counts = {}
    for drug in HYPERTENSION_DRUGS:
        drug_pop = pop.loc[pop[f'{drug}_dosage'] > 0]
        if not drug_pop.empty:
            med_counts.update(
                self.summarize_drug_by_group(drug_pop, drug, base_key))
    self.counts.update(med_counts)
def test_get_person_time_in_span(ages_and_bins, observer_config):
    """Person time per age bin must match a hand-computed expectation.

    The population holds one woman and one man for every
    ``(age_at_span_start, age_at_span_end)`` pair drawn from the integer
    ages covered by the bins.
    """
    _, age_bins = ages_and_bins
    start = int(age_bins.age_start.min())
    end = int(age_bins.age_end.max())
    n_ages = len(range(start, end))
    n_bins = len(age_bins)
    segments_per_age = [(i + 1) * (n_ages - i) for i in range(n_ages)]
    ages_per_bin = n_ages // n_bins

    expected_times = []
    for bin_number in range(n_bins):
        first = ages_per_bin * bin_number
        last = ages_per_bin * (bin_number + 1)
        expected_times.append(sum(segments_per_age[first:last]))
    age_bins['expected_time'] = expected_times

    age_starts, age_ends = zip(*combinations(range(start, end + 1), 2))
    women = pd.DataFrame({
        'age_at_span_start': age_starts,
        'age_at_span_end': age_ends,
        'sex': 'Female',
    })
    men = women.copy()
    men.loc[:, 'sex'] = 'Male'

    # Shuffle rows so the result cannot depend on input order.
    everyone = pd.concat([women, men], ignore_index=True)
    lived_in_span = everyone.sample(frac=1).reset_index(drop=True)

    base_filter = QueryString("")
    template = get_output_template(**observer_config)
    span_key = template.substitute(measure='person_time', year=2019)
    pt = get_person_time_in_span(lived_in_span, base_filter, span_key,
                                 observer_config, age_bins)

    def expected_total(per_sex_time):
        # Without sex stratification both sexes land in the same key.
        if observer_config['by_sex']:
            return per_sex_time
        return 2 * per_sex_time

    # Values are summed through a set, so equal values count only once.
    if observer_config['by_age']:
        for group, age_bin in age_bins.iterrows():
            marker = f'in_age_group_{group}'
            group_pt = sum({v for k, v in pt.items() if marker in k})
            assert group_pt == expected_total(age_bin.expected_time)
    else:
        group_pt = sum(set(pt.values()))
        assert group_pt == expected_total(age_bins.expected_time.sum())
def get_state_person_time(pop: pd.DataFrame, config: Dict[str, bool],
                          disease: str, state: str,
                          current_year: Union[str, int],
                          step_size: pd.Timedelta,
                          age_bins: pd.DataFrame) -> Dict[str, float]:
    """Return grouped person time for living simulants in a disease state.

    Simulants are selected with ``alive == "alive"`` and
    ``<disease> == "<state>"``. Each group's value is its size multiplied
    by ``to_years(step_size)``. Keys use the measure
    ``<state>_person_time`` and ``year=current_year``.
    """
    def group_person_time(group):
        return len(group) * to_years(step_size)

    template = get_output_template(**config)
    base_key = template.substitute(measure=f'{state}_person_time',
                                   year=current_year)
    state_filter = QueryString(
        f'alive == "alive" and {disease} == "{state}"')
    return get_group_counts(pop, state_filter, base_key, config, age_bins,
                            aggregate=group_person_time)
def on_time_step_prepare(self, event):
    """Accumulate person time per WHZ exposure category on every step.

    Raw WHZ exposure is converted to categories. For each category present,
    living simulants in that category are counted per group by
    ``get_group_counts``. Each count is multiplied by ``self.step_size``
    and stored in ``self.person_time`` under the group key suffixed with
    ``_in_<category>``.
    """
    # We count person time each time step if we are tracking WHZ.
    pop = self.population_view.get(event.index)
    raw_whz_exposure = self.raw_whz_exposure(event.index)
    whz_exposure = convert_whz_to_categorical(raw_whz_exposure)

    # None of these depend on the category, so build them once. One plain
    # dict is used for both the template and the group counts.
    config = self.config.to_dict()
    base_filter = QueryString('alive == "alive"')
    base_key = get_output_template(**config).substitute(
        measure='person_time', year=self.clock().year)

    for cat in whz_exposure.unique():
        in_cat = pop.loc[whz_exposure == cat]
        counts = get_group_counts(in_cat, base_filter, base_key, config,
                                  self.age_bins)
        # NOTE(review): the unit of self.step_size is not visible here.
        counts = {
            str(key) + f'_in_{cat}': value * self.step_size
            for key, value in counts.items()
        }
        self.person_time.update(counts)
def on_time_step_prepare(self, event: 'Event'):
    """Accumulate person time per anemia state for living simulants.

    The severity from ``self.anemia_severity`` is attached to the
    population as an ``anemia`` column. For every state in ``self.states``
    the grouped person time (group size times ``to_years(event.step_size)``)
    is merged into ``self.person_time``.
    """
    def group_person_time(group):
        return len(group) * to_years(event.step_size)

    pop = self.population_view.get(event.index)
    pop['anemia'] = self.anemia_severity(pop.index)

    # Ignoring the edge case where the step spans a new year.
    # Accrue all counts and time to the current year.
    for state in self.states:
        template = get_output_template(**self.config)
        state_key = template.substitute(
            measure=f'anemia_{state}_person_time', year=self.clock().year)
        state_filter = QueryString(
            f'alive == "alive" and anemia == "{state}"')
        # noinspection PyTypeChecker
        state_person_time = get_group_counts(
            pop, state_filter, state_key, self.config, self.age_bins,
            aggregate=group_person_time)
        self.person_time.update(state_person_time)
def on_time_step_prepare(self, event):
    """Accumulate person time lived during this step, per simulant group.

    ``self.get_groups(event.index)`` supplies a ``{label: index}`` mapping.
    For each group, missing ``exit_time`` values are filled with the end of
    the step, and the time lived between ``self.clock()`` and
    ``self.clock() + self.step_size`` is computed. The grouped person time
    is stored in ``self.person_time`` under the measure
    ``person_time_<label>``.
    """
    # Done here rather than on collect metrics: if someone starts treatment
    # during a time step, their person time for that step should not be
    # counted under the treated status.
    no_filter = QueryString("")
    groups = self.get_groups(event.index)

    for label, index in groups.items():
        pop = self.population_view.get(index)
        # Simulants with no exit time are treated as present to step end.
        step_end = self.clock() + self.step_size
        pop.loc[pop.exit_time.isna(), 'exit_time'] = step_end

        t_start = self.clock()
        t_end = self.clock() + self.step_size
        lived_in_span = get_lived_in_span(pop, t_start, t_end)

        template = get_output_template(**self.config.to_dict())
        span_key = template.substitute(measure=f'person_time_{label}')
        group_person_time = get_person_time_in_span(
            lived_in_span, no_filter, span_key, self.config.to_dict(),
            self.age_bins)
        self.person_time.update(group_person_time)
def on_collect_metrics(self, event):
    """Count vaccine doses given at this event time, by dose and group.

    Only simulants whose ``vaccine_event_time`` equals ``event.time`` are
    considered. For each dose in ``project_globals.VACCINE_DOSES`` the
    grouped counts are keyed with the measure
    ``<SHIGELLA_VACCINE>_<dose>_dose_count`` and merged into
    ``self.counts``.
    """
    template = get_output_template(**self.config)
    base_key = template.substitute(year=event.time.year)
    base_filter = QueryString('')

    pop = self.population_view.get(event.index)
    vaccinated_now = pop.loc[pop['vaccine_event_time'] == event.time]

    dose_counts = {}
    for dose in project_globals.VACCINE_DOSES:
        dose_filter = base_filter + f'vaccine_dose == "{dose}"'
        group_counts = get_group_counts(vaccinated_now, dose_filter,
                                        base_key, self.config, self.age_bins)
        # The measure field is filled in on each returned key.
        measure = f'{project_globals.SHIGELLA_VACCINE}_{dose}_dose_count'
        dose_counts.update({
            partial_key.substitute(measure=measure): count
            for partial_key, count in group_counts.items()
        })
    self.counts.update(dose_counts)
def _get_births(self, pop: pd.DataFrame, base_filter: QueryString,
                configuration: Dict,
                time_spans: List[Tuple[str, Tuple[pd.Timestamp, pd.Timestamp]]],
                age_bins: pd.DataFrame,
                cutoff_weight: float = None) -> Dict[str, float]:
    """Count births per group for every time span.

    Parameters
    ----------
    pop
        The population dataframe to be counted.
    base_filter
        Query that is formatted with ``start_time`` and ``end_time`` for
        each time span.
    configuration
        Passed to the output template and to the group counter.
    time_spans
        ``(year, (year_start, year_end))`` entries to count over.
    age_bins
        Passed through to ``utilities.get_group_counts``.
    cutoff_weight
        When not ``None``, only rows whose birth-weight column is less than
        or equal to this value are counted, and the measure is
        ``low_weight_births``. Otherwise the measure is ``total_births``.

    Returns
    -------
    births
        Mapping of output key to count.
    """
    # Compare against None so an explicit cutoff of 0 is still a cutoff.
    if cutoff_weight is not None:
        base_filter += (
            QueryString('{column} <= {cutoff}')
            .format(column=f'`{self.birth_weight_pipeline_name}`',
                    cutoff=cutoff_weight)
        )
        measure = 'low_weight_births'
    else:
        measure = 'total_births'

    base_key = utilities.get_output_template(**configuration).substitute(
        measure=measure)

    births = {}
    for year, (year_start, year_end) in time_spans:
        year_filter = base_filter.format(start_time=year_start,
                                         end_time=year_end)
        year_key = base_key.substitute(year=year)
        group_births = utilities.get_group_counts(
            pop, year_filter, year_key, configuration, age_bins)
        births.update(group_births)
    return births
def on_collect_metrics(self, event):
    """Accumulate supplemented days for currently supplemented simulants.

    Simulants whose lack-of-supplementation exposure is ``cat2`` are
    counted per group among the living. Each count is multiplied by
    ``self.step_size().days`` and merged into ``self.supplemented_days``.
    """
    pop = self.population_view.get(event.index)
    lack_of_supplementation = self.lack_of_vitamin_a_supplementation(
        pop.index)
    # cat1 is exposure to the lack of vitamin A supplementation, so cat2
    # means the simulant is being supplemented.
    supplemented = pop.loc[lack_of_supplementation == 'cat2']

    config = self.config.to_dict().copy()
    alive_filter = QueryString('alive == "alive"')
    template = get_output_template(**config)
    base_key = template.substitute(measure=self.measure_name,
                                   year=event.time.year)

    supplemented_counts = get_group_counts(supplemented, alive_filter,
                                           base_key, config, self.age_bins)
    supplemented_days = {
        key: count * self.step_size().days
        for key, count in supplemented_counts.items()
    }
    self.supplemented_days.update(supplemented_days)
def get_entrances(pop: pd.DataFrame, age_bins: pd.DataFrame,
                  event_time: pd.Timestamp,
                  evaluation_status: str) -> Dict[str, int]:
    """Count simulants newly given ``evaluation_status`` at ``event_time``.

    A simulant qualifies when it is alive, its ``registry_evaluation_date``
    equals ``event_time`` and its ``registry_evaluation_status`` equals
    ``evaluation_status``. Counts are stratified by age, sex and year, and
    by every combination of the two at-diagnosis risk levels; the risk
    levels are appended to each key as a suffix.
    """
    qualifies = ((pop.alive == 'alive')
                 & (pop.registry_evaluation_date == event_time)
                 & (pop.registry_evaluation_status == evaluation_status))
    config = {'by_age': True, 'by_sex': True, 'by_year': True}
    result = {}

    rcr_risk = data_values.RISKS.race_and_cytogenetic_risk_at_diagnosis
    rf_risk = data_values.RISKS.renal_function_at_diagnosis
    # One "<risk name>_{}" slot per risk, filled with the level below.
    risk_template = '_'.join([f'{risk}_{{}}' for risk in [rcr_risk, rf_risk]])

    level_pairs = itertools.product(data_values.RISK_LEVEL_MAP[rcr_risk],
                                    data_values.RISK_LEVEL_MAP[rf_risk])
    for rcr, rf in level_pairs:
        at_risk_level = ((pop[rcr_risk] == rcr)
                         & (pop[rf_risk] == rf)
                         & qualifies)
        risk_level_pop = pop.loc[at_risk_level]

        template = get_output_template(**config)
        base_key = template.substitute(
            measure=f'registry_status_newly_{evaluation_status}',
            year=event_time.year)
        group_counts = get_group_counts(risk_level_pop, "", base_key, config,
                                        age_bins)
        for key, value in group_counts.items():
            result[key + '_' + risk_template.format(rcr, rf)] = value
    return result
def on_collect_metrics(self, event: 'Event'):
    """Accumulate weighted person-time measures per stratum, age and sex.

    For living simulants this records, per group: the sum of FPG, SBP,
    LDL-C and CVD risk score values times the step length in years; the
    person time of adherent simulants; the person time of simulants at
    target; and person time per treatment. Results are relabelled by the
    stratifier and merged into ``self.results``.
    """
    pop = self.population_view.get(event.index, query='alive == "alive"')
    initial_proportion_reduction = pop['initial_treatment_proportion_reduction']

    fpg = self.fpg(pop.index)
    sbp = self.sbp(pop.index)
    ldlc = self.ldlc(pop.index)
    cvd_score = self.cvd_risk_score(pop.index)
    measure_names = ['fpg_person_time', 'sbp_person_time',
                     'ldlc_person_time', 'cv_risk_score_person_time']
    measure_map = list(zip(measure_names, [fpg, sbp, ldlc, cvd_score]))

    adherent = self.is_adherent(pop.index).astype(int)
    # "At target" means current LDL-C is at most half of the value implied
    # by undoing the initial proportional reduction.
    raw_ldlc = ldlc / (1 - initial_proportion_reduction)
    at_target = (ldlc / raw_ldlc <= 0.5).astype(int)

    # noinspection PyTypeChecker
    step_size = to_years(event.step_size)

    age_sex_filter, (ages, sexes) = get_age_sex_filter_and_iterables(
        self.config, self.age_bins)
    template = get_output_template(**self.config)
    base_key = template.substitute(year=event.time.year)
    base_filter = QueryString('alive == "alive"') + age_sex_filter

    person_time = {}
    for labels, pop_in_group in self.stratifier.group(pop):
        for group, age_group in ages:
            start = age_group.age_start
            end = age_group.age_end
            for sex in sexes:
                key_fields = dict(age_start=start, age_end=end, sex=sex,
                                  age_group=group)
                group_key = base_key.substitute(**key_fields)
                group_filter = base_filter.format(**key_fields)
                # Skip filtering when there is no query or nobody to filter.
                if group_filter and not pop_in_group.empty:
                    sub_pop = pop_in_group.query(group_filter)
                else:
                    sub_pop = pop_in_group

                for measure, attribute in measure_map:
                    measure_key = group_key.substitute(measure=measure)
                    person_time[measure_key] = sum(
                        attribute.loc[sub_pop.index] * step_size)

                adherent_pt = sum(adherent.loc[sub_pop.index] * step_size)
                adherent_key = group_key.substitute(
                    measure='adherent_person_time')
                person_time[adherent_key] = adherent_pt

                at_target_pt = sum(at_target.loc[sub_pop.index] * step_size)
                at_target_key = group_key.substitute(
                    measure='at_target_person_time')
                person_time[at_target_key] = at_target_pt

                # Start every treatment at zero so that treatments absent
                # from this group still get an entry.
                treatments = {}
                for treatment in project_globals.TREATMENT:
                    zero_key = group_key.substitute(
                        measure=f'{treatment}_person_time')
                    treatments[zero_key] = 0.
                keys_per_simulant = sub_pop[project_globals.TREATMENT.name].map(
                    lambda x: group_key.substitute(
                        measure=f'{x}_person_time'))
                observed = keys_per_simulant.value_counts() * step_size
                treatments.update(observed.to_dict())
                person_time.update(treatments)

        self.results.update(
            self.stratifier.update_labels(person_time, labels))