コード例 #1
0
ファイル: observers.py プロジェクト: ihmeuw/vivarium_ciff_sam
    def _metrics(self, index: pd.Index, metrics: Dict) -> Dict:
        """Collect birth measures for the simulants in ``index`` into ``metrics``.

        The population view is joined with the values of every pipeline in
        ``self.pipelines``. Then, for each stratification group, total births,
        summed birth weight and low-birth-weight births (births at or below
        ``results.LOW_BIRTH_WEIGHT_CUTOFF``) are computed. The results are
        relabeled with the group's labels.

        Parameters
        ----------
        index
            Index of the simulants to observe.
        metrics
            Mapping of already-collected metrics; updated in place.

        Returns
        -------
        Dict
            The same ``metrics`` mapping with the new measures merged in.
        """
        pipelines = [
            pd.Series(pipeline(index), name=pipeline_name)
            for pipeline_name, pipeline in self.pipelines.items()
        ]
        pop = pd.concat([self.population_view.get(index)] + pipelines, axis=1)

        # (getter, extra positional args) pairs. The last entry reuses
        # _get_births with a cutoff weight to count low-weight births.
        measure_getters = (
            (self._get_births, ()),
            (self._get_birth_weight_sum, ()),
            (self._get_births, (results.LOW_BIRTH_WEIGHT_CUTOFF,)),
        )

        # Serialize the configuration once and reuse it for every group
        # instead of re-serializing it on each loop iteration.
        config_dict = self.configuration.to_dict()
        # Doubled braces survive the f-string as literal {start_time} /
        # {end_time} placeholders, which the getters fill per time span.
        base_filter = QueryString(f'"{{start_time}}" <= {self.entrance_time_column_name} '
                                  f'and {self.entrance_time_column_name} < "{{end_time}}"')
        time_spans = utilities.get_time_iterable(config_dict, self.start_time, self.clock())

        for labels, pop_in_group in self.stratifier.group(pop):
            args = (pop_in_group, base_filter, config_dict, time_spans, self.age_bins)

            for measure_getter, extra_args in measure_getters:
                measure_data = measure_getter(*args, *extra_args)
                measure_data = self.stratifier.update_labels(measure_data, labels)
                metrics.update(measure_data)

        return metrics
コード例 #2
0
def test_query_string(a, b):
    """``+`` and ``+=`` join QueryStrings with ' and ' and keep the QueryString type."""

    def check(value, expected):
        # The value must compare equal both to plain text and to a
        # QueryString, and must itself still be a QueryString.
        assert value == expected
        assert value == QueryString(expected)
        assert isinstance(value, QueryString)

    check(a + b, 'a and b')
    check(b + a, 'b and a')

    a += b
    check(a, 'a and b')

    # ``a`` now holds 'a and b', so appending it to ``b`` nests both.
    b += a
    check(b, 'b and a and b')
コード例 #3
0
def test_query_string_empty(reference, test):
    """Combining a query with an empty query leaves the query text unchanged."""
    expected = str(reference)

    def check(value):
        # Equal to the untouched reference text, as str and as QueryString,
        # and still typed as a QueryString.
        assert value == expected
        assert value == QueryString(expected)
        assert isinstance(value, QueryString)

    check(reference + test)
    check(test + reference)

    reference += test
    check(reference)

    test += reference
    check(test)
コード例 #4
0
ファイル: observers.py プロジェクト: ihmeuw/vivarium_ciff_sam
    def _get_births(self, pop: pd.DataFrame, base_filter: QueryString, configuration: Dict,
                    time_spans: List[Tuple[str, Tuple[pd.Timestamp, pd.Timestamp]]], age_bins: pd.DataFrame,
                    cutoff_weight: float = None) -> Dict[str, float]:
        """Count births in each time span, optionally only those under a weight cutoff.

        Parameters
        ----------
        pop
            Population (including pipeline value columns) to count.
        base_filter
            Query with ``{start_time}`` and ``{end_time}`` placeholders that
            are filled with each span's bounds.
        configuration
            Observer configuration used to build output keys and groupings.
        time_spans
            ``(year, (start, end))`` pairs to report on.
        age_bins
            Age bins passed through to ``utilities.get_group_counts``.
        cutoff_weight
            When given, only births whose birth-weight pipeline value is at or
            below this value are counted, under the measure
            ``low_weight_births``. Otherwise all births are counted as
            ``total_births``. A cutoff of ``0`` is honoured rather than being
            mistaken for "no cutoff".

        Returns
        -------
        Dict[str, float]
            Birth counts keyed by output key.
        """
        if cutoff_weight is not None:
            # Backticks quote the pipeline name inside the pandas query,
            # presumably because it is not a valid Python identifier.
            base_filter += (
                QueryString('{column} <= {cutoff}')
                .format(column=f'`{self.birth_weight_pipeline_name}`', cutoff=cutoff_weight)
            )
            measure = 'low_weight_births'
        else:
            measure = 'total_births'

        base_key = utilities.get_output_template(**configuration).substitute(measure=measure)

        births = {}
        for year, (year_start, year_end) in time_spans:
            year_filter = base_filter.format(start_time=year_start, end_time=year_end)
            year_key = base_key.substitute(year=year)
            group_births = utilities.get_group_counts(pop, year_filter, year_key, configuration, age_bins)
            births.update(group_births)
        return births
コード例 #5
0
ファイル: observers.py プロジェクト: ihmeuw/vivarium_ciff_sam
    def _get_birth_weight_sum(self, pop: pd.DataFrame, base_filter: QueryString, configuration: Dict,
                              time_spans: List[Tuple[str, Tuple[pd.Timestamp, pd.Timestamp]]],
                              age_bins: pd.DataFrame) -> Dict[str, float]:
        """Sum the birth-weight pipeline values of births in each time span.

        Parameters
        ----------
        pop
            Population (including pipeline value columns) to aggregate.
        base_filter
            Query with ``{start_time}`` and ``{end_time}`` placeholders that
            are filled with each span's bounds.
        configuration
            Observer configuration used to build output keys and groupings.
        time_spans
            ``(year, (start, end))`` pairs to report on.
        age_bins
            Age bins passed through to ``utilities.get_group_counts``.

        Returns
        -------
        Dict[str, float]
            Summed birth weight keyed by output key (measure ``birth_weight_sum``).
        """
        template = utilities.get_output_template(**configuration).substitute(measure='birth_weight_sum')

        def total_weight(frame):
            return frame[self.birth_weight_pipeline_name].sum()

        sums_by_key = {}
        for year, (span_start, span_end) in time_spans:
            span_filter = base_filter.format(start_time=span_start, end_time=span_end)
            span_key = template.substitute(year=year)
            sums_by_key.update(
                utilities.get_group_counts(pop, span_filter, span_key, configuration, age_bins,
                                           total_weight)
            )
        return sums_by_key
コード例 #6
0
def get_person_time(pop: pd.DataFrame, config: Dict[str, bool],
                    current_year: Union[str, int], step_size: pd.Timedelta,
                    age_bins: pd.DataFrame) -> Dict[str, float]:
    """Compute the person time accrued by living simulants over one time step.

    Parameters
    ----------
    pop
        Population to count. It must have an ``alive`` column plus any
        columns needed by the groupings in ``config``.
    config
        Observer configuration flags (e.g. ``by_age``, ``by_sex``).
    current_year
        Year the person time is attributed to.
    step_size
        Length of the time step.
    age_bins
        Age bins for age-stratified counts.

    Returns
    -------
    Dict[str, float]
        Person time keyed by output key. Every simulant in a group
        contributes ``to_years(step_size)``.
    """
    base_key = get_output_template(**config).substitute(measure='person_time',
                                                        year=current_year)
    # Nothing is interpolated here, so a plain literal instead of an f-string.
    base_filter = QueryString('alive == "alive"')
    # The per-simulant contribution is the same for every group; compute it once.
    years_per_simulant = to_years(step_size)
    person_time = get_group_counts(
        pop,
        base_filter,
        base_key,
        config,
        age_bins,
        aggregate=lambda x: len(x) * years_per_simulant)
    return person_time
コード例 #7
0
def get_transition_count(pop: pd.DataFrame, config: Dict[str, bool], disease: str,
                         transition: project_globals.TransitionString,
                         event_time: pd.Timestamp,
                         age_bins: pd.DataFrame) -> Dict[str, float]:
    """Counts transitions that occurred this step.

    A simulant has made ``transition`` when its ``previous_<disease>`` column
    equals ``transition.from_state`` and its ``<disease>`` column equals
    ``transition.to_state``. Counts are reported under the measure
    ``<transition>_event_count`` for ``event_time.year``.
    """
    was_in_from_state = pop[f'previous_{disease}'] == transition.from_state
    now_in_to_state = pop[disease] == transition.to_state
    transitioned = pop.loc[was_in_from_state & now_in_to_state]

    output_key = get_output_template(**config).substitute(
        measure=f'{transition}_event_count', year=event_time.year)
    no_filter = QueryString('')
    return get_group_counts(transitioned, no_filter, output_key, config,
                            age_bins)
コード例 #8
0
def get_births(pop: pd.DataFrame, config: Dict[str, bool], sim_start: pd.Timestamp,
               sim_end: pd.Timestamp) -> Dict[str, int]:
    """Counts the number of births and births with neural tube defects prevalent.

    Parameters
    ----------
    pop
        The population dataframe to be counted. It must contain sufficient
        columns for any necessary filtering (e.g. the ``age`` column if
        filtering by age).
    config
        A dict with ``by_age``, ``by_sex``, and ``by_year`` keys and
        boolean values.
    sim_start
        The simulation start time.
    sim_end
        The simulation end time.

    Returns
    -------
    births
        All births (measure ``live_births``) and births with neural tube
        defects present (measure ``born_with_ntds``), keyed by output key.
    """
    base_filter = QueryString('')
    base_key = get_output_template(**config)
    time_spans = get_time_iterable(config, sim_start, sim_end)

    # Neither of these depends on the time span, so build them once.
    ntd_filter = base_filter + (f'{project_globals.NTD_MODEL_NAME} == '
                                f'"{project_globals.NTD_MODEL_NAME}"')
    # Both counts receive an empty age-bin frame (presumably this disables
    # age stratification of births; confirm against get_group_counts).
    empty_age_bins = pd.DataFrame()

    births = {}
    for year, (t_start, t_end) in time_spans:
        # Clip the reporting span to the simulation window.
        start = max(sim_start, t_start)
        end = min(sim_end, t_end)
        born_in_span = pop.query(
            f'"{start}" <= entrance_time and entrance_time < "{end}"')

        live_births_key = base_key.substitute(measure='live_births', year=year)
        births.update(get_group_counts(born_in_span, base_filter,
                                       live_births_key, config,
                                       empty_age_bins))

        ntd_births_key = base_key.substitute(measure='born_with_ntds', year=year)
        births.update(get_group_counts(born_in_span, ntd_filter,
                                       ntd_births_key, config,
                                       empty_age_bins))
    return births
コード例 #9
0
def get_state_person_time(pop: pd.DataFrame, config: Dict[str, bool],
                          disease: str, state: str,
                          current_year: Union[str, int],
                          step_size: pd.Timedelta,
                          age_bins: pd.DataFrame) -> Dict[str, float]:
    """Custom person time getter that handles state column name assumptions.

    Counts person time for living simulants whose ``disease`` column equals
    ``state``. The result is reported under the measure
    ``<state>_person_time``, and each matching simulant contributes
    ``to_years(step_size)``.
    """
    in_state = QueryString(f'alive == "alive" and {disease} == "{state}"')
    measure_key = get_output_template(**config).substitute(
        measure=f'{state}_person_time', year=current_year)

    def time_in_group(group):
        return len(group) * to_years(step_size)

    return get_group_counts(pop, in_state, measure_key, config, age_bins,
                            aggregate=time_in_group)
コード例 #10
0
def test_get_person_time_in_span(ages_and_bins, observer_config):
    """get_person_time_in_span splits span person time into the right age bins.

    Builds every (start_age, end_age) span with integer ages covering the full
    age-bin range, for both sexes, and checks the reported person time per
    age bin (or in total) against a closed-form expected value.
    """
    _, age_bins = ages_and_bins
    start = int(age_bins.age_start.min())
    end = int(age_bins.age_end.max())
    n_ages = len(list(range(start, end)))
    n_bins = len(age_bins)
    # A span (s, e) covers the one-year interval starting at age start + i
    # exactly when s <= start + i and e >= start + i + 1: that is (i + 1)
    # choices of s times (n_ages - i) choices of e. So entry i is the
    # person-years lived at that age, per sex.
    segments_per_age = [(i + 1) * (n_ages - i) for i in range(n_ages)]
    # NOTE(review): assumes every age bin spans the same whole number of
    # years; unequal bins would make this slicing wrong.
    ages_per_bin = n_ages // n_bins
    age_bins['expected_time'] = [
        sum(segments_per_age[ages_per_bin * i:ages_per_bin * (i + 1)])
        for i in range(n_bins)
    ]

    # Every pair of distinct integer ages in [start, end] becomes one span.
    age_starts, age_ends = zip(*combinations(range(start, end + 1), 2))
    women = pd.DataFrame({
        'age_at_span_start': age_starts,
        'age_at_span_end': age_ends,
        'sex': 'Female'
    })
    men = women.copy()
    men.loc[:, 'sex'] = 'Male'

    # Shuffle the rows so the result cannot depend on input order.
    lived_in_span = pd.concat(
        [women, men], ignore_index=True).sample(frac=1).reset_index(drop=True)
    base_filter = QueryString("")
    span_key = get_output_template(**observer_config).substitute(
        measure='person_time', year=2019)

    pt = get_person_time_in_span(lived_in_span, base_filter, span_key,
                                 observer_config, age_bins)

    # NOTE(review): ``set`` collapses equal values. When by_sex is set, the
    # male and female entries are equal by construction and count only once,
    # hence the comparison against the single-sex expected time. This would
    # also silently collapse distinct keys that happen to share a value.
    if observer_config['by_age']:
        for group, age_bin in age_bins.iterrows():
            group_pt = sum(
                set([v for k, v in pt.items()
                     if f'in_age_group_{group}' in k]))
            if observer_config['by_sex']:
                assert group_pt == age_bin.expected_time
            else:
                assert group_pt == 2 * age_bin.expected_time
    else:
        group_pt = sum(set(pt.values()))
        if observer_config['by_sex']:
            assert group_pt == age_bins.expected_time.sum()
        else:
            assert group_pt == 2 * age_bins.expected_time.sum()
コード例 #11
0
 def on_time_step_prepare(self, event):
     """Accrue person time for each categorical WHZ exposure level.

     The raw WHZ exposure is converted to categories. For each category
     present, the living simulants in it are counted per group. Each count
     is scaled by ``self.step_size`` and stored in ``self.person_time``
     under its key suffixed with ``_in_<category>``.
     """
     # we count person time each time step if we are tracking WHZ
     pop = self.population_view.get(event.index)
     raw_whz_exposure = self.raw_whz_exposure(event.index)
     whz_exposure = convert_whz_to_categorical(raw_whz_exposure)

     # The filter, output key and configuration do not depend on the
     # category, so build them once rather than once per category.
     base_filter = QueryString('alive == "alive"')
     base_key = get_output_template(**self.config).substitute(
         measure='person_time', year=self.clock().year)
     config = self.config.to_dict()

     for cat in whz_exposure.unique():
         in_cat = pop.loc[whz_exposure == cat]
         counts = get_group_counts(in_cat, base_filter, base_key, config,
                                   self.age_bins)
         self.person_time.update({
             str(key) + f'_in_{cat}': value * self.step_size
             for key, value in counts.items()
         })
コード例 #12
0
 def on_time_step_prepare(self, event: 'Event'):
     """Accrue person time spent in each anemia state during this step.

     The current anemia severity is attached to the population as an
     ``anemia`` column. Then, for every state in ``self.states``, the person
     time of living simulants in that state is merged into
     ``self.person_time``.
     """
     population = self.population_view.get(event.index)
     population['anemia'] = self.anemia_severity(population.index)

     def step_person_time(group):
         return len(group) * to_years(event.step_size)

     # Ignoring the edge case where the step spans a new year.
     # Accrue all counts and time to the current year.
     for severity in self.states:
         output_key = get_output_template(**self.config).substitute(
             measure=f'anemia_{severity}_person_time',
             year=self.clock().year,
         )
         severity_filter = QueryString(
             f'alive == "alive" and anemia == "{severity}"')
         # noinspection PyTypeChecker
         severity_time = get_group_counts(
             population, severity_filter, output_key, self.config,
             self.age_bins, aggregate=step_person_time)
         self.person_time.update(severity_time)
コード例 #13
0
    def on_time_step_prepare(self, event):
        """Accrue each group's person time over the upcoming time step.

        For every group from ``self.get_groups``, simulants with no
        ``exit_time`` are given the end of this step as their exit time.
        Their person time within ``[clock, clock + step_size]`` is then
        recorded in ``self.person_time`` under ``person_time_<group key>``.
        """
        # Counted here rather than in on_collect_metrics: if someone starts
        # treatment during a time step, their person time for that step
        # should not be counted in the treated status.
        base_filter = QueryString("")
        # The span bounds and configuration are identical for every group.
        t_start = self.clock()
        t_end = t_start + self.step_size
        config = self.config.to_dict()

        for key, index in self.get_groups(event.index).items():
            pop = self.population_view.get(index)
            # Presumably simulants that have not exited yet; treat them as
            # exiting at the end of this step.
            pop.loc[pop.exit_time.isna(), 'exit_time'] = t_end

            lived_in_span = get_lived_in_span(pop, t_start, t_end)

            span_key = get_output_template(**config).substitute(
                measure=f'person_time_{key}')
            person_time_in_span = get_person_time_in_span(
                lived_in_span, base_filter, span_key, config, self.age_bins)
            self.person_time.update(person_time_in_span)
コード例 #14
0
    def on_collect_metrics(self, event):
        """Count vaccine doses whose vaccine event time equals this event's time.

        For each dose in ``project_globals.VACCINE_DOSES``, simulants with a
        matching ``vaccine_dose`` are counted per group. Each count is stored
        in ``self.counts`` under the measure
        ``<SHIGELLA_VACCINE>_<dose>_dose_count``.
        """
        base_key = get_output_template(**self.config).substitute(
            year=event.time.year)
        base_filter = QueryString('')

        pop = self.population_view.get(event.index)
        # Plain column name: the original f-string had nothing to interpolate.
        pop = pop.loc[pop['vaccine_event_time'] == event.time]

        dose_counts = {}
        for dose in project_globals.VACCINE_DOSES:
            dose_filter = base_filter + f'vaccine_dose == "{dose}"'
            # The measure name depends only on the dose, not on the group.
            measure = f'{project_globals.SHIGELLA_VACCINE}_{dose}_dose_count'
            group_counts = get_group_counts(pop, dose_filter, base_key,
                                            self.config, self.age_bins)
            for group_key, count in group_counts.items():
                dose_counts[group_key.substitute(measure=measure)] = count

        self.counts.update(dose_counts)
コード例 #15
0
    def on_collect_metrics(self, event):
        """Accumulate days of vitamin A supplementation for this step.

        Living simulants whose lack-of-supplementation exposure is ``cat2``
        are counted per group. Each count is multiplied by the step length in
        days and merged into ``self.supplemented_days``.
        """
        pop = self.population_view.get(event.index)
        current_lack_of_supplementation_exposure = self.lack_of_vitamin_a_supplementation(
            pop.index)

        # cat1 is exposure to the risk (lack of vitamin A supplementation),
        # so cat2 means the simulant is being supplemented.
        current_supplemented_pop = pop.loc[
            current_lack_of_supplementation_exposure == 'cat2']

        config = self.config.to_dict().copy()
        # Nothing is interpolated here, so a plain literal instead of an f-string.
        base_filter = QueryString('alive == "alive"')
        base_key = get_output_template(**config).substitute(
            measure=self.measure_name, year=event.time.year)
        current_supplemented_count = get_group_counts(current_supplemented_pop,
                                                      base_filter, base_key,
                                                      config, self.age_bins)
        # NOTE(review): if step_size() returns a pd.Timedelta, ``.days`` drops
        # any sub-day remainder; confirm steps are whole days.
        step_days = self.step_size().days
        current_supplemented_count = {
            k: v * step_days
            for k, v in current_supplemented_count.items()
        }
        self.supplemented_days.update(current_supplemented_count)
コード例 #16
0
@pytest.fixture()
def builder(mocker):
    """Mock simulation builder whose ``data.load`` returns a small age-bin table."""
    mock_builder = mocker.MagicMock()
    age_bin_rows = [
        (0, 'youngest', 1),
        (1, 'younger', 4),
        (4, 'young', 6),
    ]
    age_bins = pd.DataFrame(age_bin_rows,
                            columns=['age_start', 'age_group_name', 'age_end'])
    mock_builder.data.load.return_value = age_bins
    return mock_builder


@pytest.mark.parametrize(
    'reference, test',
    product([QueryString(''), QueryString('abc')], [QueryString(''), '']))
def test_query_string_empty(reference, test):
    """Adding an empty query (QueryString or plain str) leaves a query unchanged.

    Checks ``+`` in both operand orders and ``+=`` on both operands. Every
    result must equal the reference text and still be a ``QueryString``.
    """
    result = str(reference)
    assert reference + test == result
    assert reference + test == QueryString(result)
    assert isinstance(reference + test, QueryString)

    assert test + reference == result
    assert test + reference == QueryString(result)
    assert isinstance(test + reference, QueryString)

    reference += test
    assert reference == result
    assert reference == QueryString(result)
    assert isinstance(reference, QueryString)

    # Symmetric in-place case: adding onto the empty operand must also give
    # a QueryString, including when ``test`` is a plain ``str``.
    test += reference
    assert test == result
    assert test == QueryString(result)
    assert isinstance(test, QueryString)
コード例 #17
0
    def on_collect_metrics(self, event: 'Event'):
        """Accumulate risk-factor, adherence, target and treatment person time.

        For each stratification group of living simulants, and each age/sex
        cell from ``get_age_sex_filter_and_iterables``, this records:

        * ``fpg``/``sbp``/``ldlc``/``cv_risk_score`` person time: the sum of
          each simulant's value times ``to_years(event.step_size)``;
        * ``adherent_person_time``: step time of simulants flagged adherent;
        * ``at_target_person_time``: step time of simulants whose LDL-C is at
          most half of ``raw_ldlc``;
        * ``<treatment>_person_time``: step time per treatment category,
          zero for categories with no simulants in the cell.

        Each group's values are relabeled with its stratification labels and
        merged into ``self.results``.
        """
        pop = self.population_view.get(event.index, query='alive == "alive"')
        initial_proportion_reduction = pop[
            'initial_treatment_proportion_reduction']

        fpg = self.fpg(pop.index)
        sbp = self.sbp(pop.index)
        ldlc = self.ldlc(pop.index)
        cvd_score = self.cvd_risk_score(pop.index)
        # (measure name, per-simulant value) pairs accrued as value * step.
        measure_map = [
            ('fpg_person_time', fpg),
            ('sbp_person_time', sbp),
            ('ldlc_person_time', ldlc),
            ('cv_risk_score_person_time', cvd_score),
        ]

        adherent = self.is_adherent(pop.index).astype(int)

        # raw_ldlc divides out the initial treatment reduction (presumably the
        # pre-treatment LDL-C). A simulant is "at target" when the current
        # value is at most half of it.
        raw_ldlc = ldlc / (1 - initial_proportion_reduction)
        at_target = (ldlc / raw_ldlc <= 0.5).astype(int)

        # noinspection PyTypeChecker
        step_size = to_years(event.step_size)

        age_sex_filter, (ages, sexes) = get_age_sex_filter_and_iterables(
            self.config, self.age_bins)
        base_key = get_output_template(**self.config).substitute(
            year=event.time.year)
        base_filter = QueryString('alive == "alive"') + age_sex_filter
        for labels, pop_in_group in self.stratifier.group(pop):
            # A fresh dict per stratification group. Previously a single dict
            # was shared across groups, so entries from an earlier group that
            # a later group did not overwrite were relabeled and reported
            # again under the later group's labels.
            person_time = {}
            for group, age_group in ages:
                start, end = age_group.age_start, age_group.age_end
                for sex in sexes:
                    filter_kwargs = {
                        'age_start': start,
                        'age_end': end,
                        'sex': sex,
                        'age_group': group
                    }
                    group_key = base_key.substitute(**filter_kwargs)
                    group_filter = base_filter.format(**filter_kwargs)

                    # Skip the query when there is nothing to filter on or no
                    # rows to filter.
                    sub_pop = (pop_in_group.query(group_filter) if group_filter
                               and not pop_in_group.empty else pop_in_group)

                    for measure, attribute in measure_map:
                        person_time[group_key.substitute(
                            measure=measure)] = sum(
                                attribute.loc[sub_pop.index] * step_size)

                    adherent_pt = sum(adherent.loc[sub_pop.index] * step_size)
                    person_time[group_key.substitute(
                        measure='adherent_person_time')] = adherent_pt

                    at_target_pt = sum(at_target.loc[sub_pop.index] *
                                       step_size)
                    person_time[group_key.substitute(
                        measure='at_target_person_time')] = at_target_pt

                    # Seed every treatment with zero so that categories with
                    # no simulants in this cell are still reported.
                    treatments = {
                        group_key.substitute(
                            measure=f'{treatment}_person_time'): 0.
                        for treatment in project_globals.TREATMENT
                    }

                    treatments.update(
                        (sub_pop[project_globals.TREATMENT.name].map(
                            lambda x: group_key.substitute(
                                measure=f'{x}_person_time')).value_counts() *
                         step_size).to_dict())
                    person_time.update(treatments)

            self.results.update(
                self.stratifier.update_labels(person_time, labels))