def scenario_builder(self):
        """Builds the SIR environment and the agent that will act in it.

        Returns:
          A two-tuple ``(env, agent)``. Note the environment comes first.
        """
        population_graph = GRAPHS[self.graph_name]

        # Treatment transition matrix: row = health state before treatment,
        # column = health state after (presumably 0=susceptible, 1=infected,
        # 2=recovered). Treated susceptible people jump directly to recovered
        # without ever getting sick; the other rows are left unchanged.
        susceptible_to_recovered = np.array([
            [0, 0, 1],
            [0, 1, 0],
            [0, 0, 1],
        ])

        # Everybody starts out healthy (state 0).
        all_healthy = [0] * population_graph.number_of_nodes()

        env = infectious_disease.build_sir_model(
            population_graph=population_graph,
            infection_probability=self.infection_probability,
            infected_exit_probability=self.infected_exit_probability,
            num_treatments=self.num_treatments,
            max_treatments=1,
            burn_in=self.burn_in,
            treatment_transition_matrix=susceptible_to_recovered,
            initial_health_state=all_healthy,
            initial_health_state_seed=self.env_seed,
        )

        # The agent gets a null reward; its params are derived from the
        # environment's initial parameters.
        agent = self.agent_constructor(
            env.action_space,
            rewards.NullReward(),
            env.observation_space,
            params=infectious_disease_agents.env_to_agent_params(
                env.initial_params),
        )

        return env, agent
Example #2
0
    def test_disease_progresses_with_contact_sir(self):
        """Infection spreads through a fully connected SIR population.

        Every odd-indexed person starts infected, infection on contact is
        certain (probability 1.0), and infected_exit_probability is 0.0, so
        after running for a while more people should be infected than at the
        start.
        """
        steps = 10
        n_people = 5

        # Well-connected population: here every pair of people is in contact.
        contact_graph = nx.Graph()
        contact_graph.add_nodes_from(range(n_people))
        contact_graph.add_edges_from(complete_graph_edge_list(n_people))

        # Even indices start in state 0, odd indices in state 1.
        starting_health = [
            idx % 2 for idx in range(contact_graph.number_of_nodes())
        ]

        env = infectious_disease.build_sir_model(
            population_graph=contact_graph,
            infection_probability=1.0,
            infected_exit_probability=0.0,
            num_treatments=0,
            max_treatments=10,
            initial_health_state=starting_health,
        )
        agent = random_agents.RandomAgent(
            env.action_space, lambda x: 0, env.observation_space)

        # Snapshot the state so it can be compared after the simulation runs.
        state_before = copy.deepcopy(env.state)

        test_util.run_test_simulation(env=env, agent=agent, num_steps=steps)

        infected = env.state_name_to_index['infected']
        self.assertGreater(
            num_in_health_state(env.state, infected),
            num_in_health_state(state_before, infected))
Example #3
0
    def test_observation_is_up_to_date(self):
        """Checks the observation reflects the population about to be treated.

        Treatment must be applied BEFORE the disease spreads at each step, not
        after; otherwise patient zero would already have infected the others.
        """
        rng_seed = 1
        n_people = 25

        # Patient zero (index 0) starts infected (1); all others susceptible (0).
        start_health = [1] + [0] * (n_people - 1)

        # Fully connected contact graph.
        contact_graph = nx.Graph()
        contact_graph.add_nodes_from(range(n_people))
        contact_graph.add_edges_from(complete_graph_edge_list(n_people))

        # Treatment moves infected people (row 1) to recovered (column 2);
        # susceptible and recovered people are left as they are.
        cure_infected = np.array([
            [1, 0, 0],
            [0, 0, 1],
            [0, 0, 1],
        ])

        env = infectious_disease.build_sir_model(
            population_graph=contact_graph,
            infection_probability=1.0,
            infected_exit_probability=0.0,
            treatment_transition_matrix=cure_infected,
            num_treatments=1,
            max_treatments=1,
            initial_health_state=start_health,
        )

        env.set_scalar_reward(lambda x: 0)
        env.seed(rng_seed)
        env.reset()
        # Treat person 0 on the first step.
        observation, _reward, _done, _info = env.step([0])

        # If treatment landed after the spread, many more people would be sick;
        # instead only patient zero changed state (to recovered, 2).
        want = [2] + [0] * (n_people - 1)
        self.assertEqual(observation['health_states'].tolist(), want)