def test_side_effects():
    """Check that ``_transition_side_effect`` fires on each transition into ``DoneState``.

    ``DoneState`` increments the ``count`` column for the agents passed to
    its side-effect hook. Both ``start`` and ``done`` share one
    probability-1.0 transition into ``done``. After the first call to
    ``machine.transition`` every agent's count is expected to be 1, and
    after a second call (done -> done) it is expected to be 2.
    """
    class DoneState(State):
        @uses_columns(['count'])
        def _transition_side_effect(self, index, population_view):
            current = population_view.get(index)
            population_view.update(current['count'] + 1)

    done = DoneState('done')
    start = State('start')
    always_done = Transition(done, lambda agents: np.full(len(agents), 1.0))
    # The same Transition instance is attached to both states, so agents
    # already in ``done`` re-enter it on the next step.
    for state in (start, done):
        state.transition_set.append(always_done)

    machine = Machine('state')
    machine.states.extend([start, done])
    simulation = setup_simulation([
        machine,
        _population_fixture('state', 'start'),
        _population_fixture('count', 0),
    ])

    for expected_count in (1, 2):
        machine.transition(simulation.population.population.index)
        assert np.all(simulation.population.population['count'] == expected_count)
def test_interpolated_tables__exact_values_at_input_points():
    """Check that a year-valued table matches the fractional year at its input years.

    The table's value is the ``year`` parameter itself. For each distinct
    input year, the simulation clock is set to January 1 of that year. The
    table's output is then compared with the fractional current year.

    The expected value uses the same convention as the other interpolation
    tests, ``year + day_of_year / 365.25``. It is not a hard-coded
    ``year + 1/365``, so all tests agree on one definition.
    """
    year_data = build_table(lambda age, sex, year: year)
    input_years = year_data.year.unique()
    simulation = setup_simulation([generate_test_population], 10000)
    manager = simulation.tables
    # Keep the raw frame and the built table under distinct names.
    years = manager.build_table(year_data)

    for year in input_years:
        # ``unique()`` yields numpy scalars; ``datetime`` wants a plain int.
        simulation.current_time = datetime(year=int(year), month=1, day=1)
        now = simulation.current_time
        fractional_year = now.year + now.timetuple().tm_yday / 365.25
        assert np.allclose(years(simulation.population.population.index),
                           fractional_year, rtol=1.e-5)
def test_transition():
    """Check that a probability-1.0 transition moves every agent out of ``start``.

    After one call to ``machine.transition``, the whole population is
    expected to be in the ``done`` state.
    """
    done_state = State('done')
    start_state = State('start')
    certain = Transition(done_state, lambda agents: np.full(len(agents), 1.0))
    start_state.transition_set.append(certain)

    machine = Machine('state')
    machine.states.extend([start_state, done_state])
    sim = setup_simulation([machine, _population_fixture('state', 'start')])

    machine.transition(sim.population.population.index)

    # Re-read the population after the transition rather than caching it.
    final_states = sim.population.population.state
    assert np.all(final_states == 'done')
def test_null_transition():
    """Check that a 0.5 transition plus self-transitions moves about half the agents.

    ``start`` gets one transition to ``a`` with probability 0.5, and then
    ``allow_self_transitions()`` is called on it. With 10,000 agents, the
    share ending up in ``a`` is expected to round to 0.5.
    """
    def half(agents):
        return np.full(len(agents), 0.5)

    a_state = State('a')
    start_state = State('start')
    start_state.add_transition(a_state, probability_func=half)
    start_state.allow_self_transitions()

    machine = Machine('state', states=[start_state, a_state])
    simulation = setup_simulation(
        [machine, _population_fixture('state', 'start')],
        population_size=10000,
    )

    machine.transition(simulation.population.population.index)

    population = simulation.population.population
    fraction_in_a = (population.state == 'a').sum() / len(population)
    assert round(fraction_in_a, 1) == 0.5
def test_no_null_transition():
    """Check that agents split between ``a`` and ``b`` when the null transition is disabled.

    ``start`` gets two transitions built without explicit probability
    functions, and ``allow_null_transition`` is set to False on its
    transition set. With 10,000 agents, the share ending up in ``a`` is
    expected to round to 0.5.
    """
    a_state, b_state = State('a'), State('b')
    start_state = State('start')
    to_a, to_b = Transition(a_state), Transition(b_state)

    transitions = start_state.transition_set
    transitions.allow_null_transition = False
    transitions.extend((to_a, to_b))

    machine = Machine('state')
    machine.states.extend([start_state, a_state, b_state])
    simulation = setup_simulation(
        [machine, _population_fixture('state', 'start')],
        population_size=10000,
    )

    machine.transition(simulation.population.population.index)

    population = simulation.population.population
    fraction_in_a = (population.state == 'a').sum() / len(population)
    assert round(fraction_in_a, 1) == 0.5
def test_choice():
    """Check that two equal-probability transitions split agents about evenly.

    ``start`` gets transitions to ``a`` and ``b``, each with probability 0.5.
    With 10,000 agents, the share ending up in ``a`` is expected to round
    to 0.5.
    """
    def half(agents):
        return np.full(len(agents), 0.5)

    a_state, b_state = State('a'), State('b')
    start_state = State('start')
    to_a, to_b = Transition(a_state, half), Transition(b_state, half)
    start_state.transition_set.extend((to_a, to_b))

    machine = Machine('state')
    machine.states.extend([start_state, a_state, b_state])
    simulation = setup_simulation(
        [machine, _population_fixture('state', 'start')],
        population_size=10000,
    )

    machine.transition(simulation.population.population.index)

    population = simulation.population.population
    fraction_in_a = (population.state == 'a').sum() / len(population)
    assert round(fraction_in_a, 1) == 0.5
def test_interpolated_tables():
    """Check that interpolated tables reproduce their generating functions.

    Three tables are built:

    - one whose value equals ``year``;
    - one whose value equals ``age``;
    - a one-dimensional copy of the age table with only ``age`` as a
      parameter column.

    Their outputs are compared with the fractional current year and each
    simulant's age. The comparison runs at the initial time, and again
    after advancing the clock by ``30.5 * 125`` days and every age by
    ``125/12`` years.
    """
    year_frame = build_table(lambda age, sex, year: year)
    age_frame = build_table(lambda age, sex, year: age)
    age_only_frame = age_frame.copy()
    del age_only_frame['year']
    age_only_frame = age_only_frame.drop_duplicates()

    simulation = setup_simulation([generate_test_population], 10000)
    manager = simulation.tables
    years = manager.build_table(year_frame)
    ages = manager.build_table(age_frame)
    one_d_age = manager.build_table(age_only_frame, parameter_columns=('age',))

    def assert_tables_track_population():
        # Query all three tables first, then compare against current state.
        index = simulation.population.population.index
        result_years = years(index)
        result_ages = ages(index)
        result_ages_1d = one_d_age(index)

        now = simulation.current_time
        fractional_year = now.year + now.timetuple().tm_yday / 365.25

        assert np.allclose(result_years, fractional_year)
        assert np.allclose(result_ages, simulation.population.population.age)
        assert np.allclose(result_ages_1d, simulation.population.population.age)

    assert_tables_track_population()

    # Advance ~125 months and age every simulant by the same amount.
    simulation.current_time += timedelta(days=30.5 * 125)
    simulation.population._population.age += 125/12

    assert_tables_track_population()
def test_interpolated_tables_without_uniterpolated_columns():
    """Check that a table with no key columns interpolates over parameters only.

    The ``sex`` column is dropped from a year-valued table, so the table is
    built with ``key_columns=()`` and ``('year', 'age')`` as parameter
    columns. Its output is compared with the fractional current year, both
    initially and after advancing the clock by ``30.5 * 125`` days.

    NOTE(review): "uniterpolated" in the name is a typo for
    "uninterpolated". The name is kept so the test ID stays stable.
    """
    year_frame = build_table(lambda age, sex, year: year)
    del year_frame['sex']
    year_frame = year_frame.drop_duplicates()

    simulation = setup_simulation([generate_test_population], 10000)
    manager = simulation.tables
    years = manager.build_table(
        year_frame, key_columns=(), parameter_columns=('year', 'age',))

    def assert_matches_fractional_year():
        result_years = years(simulation.population.population.index)
        now = simulation.current_time
        expected = now.year + now.timetuple().tm_yday / 365.25
        assert np.allclose(result_years, expected)

    assert_matches_fractional_year()
    simulation.current_time += timedelta(days=30.5 * 125)
    assert_matches_fractional_year()