Ejemplo n.º 1
0
    def __init__(self, params):
        """Build the action/observation spaces and the initial state.

        Args:
          params: Environment parameters. This constructor reads
            `population_graph`, `max_treatments` and `state_names` from it
            and then hands the same object to the parent constructor.
        """
        num_people = params.population_graph.number_of_nodes()

        # An action is a vector with one entry per treatment that may be
        # handed out in a single timestep (params.max_treatments entries).
        # Every entry ranges over num_people values, i.e. it is the index of
        # the person receiving that treatment.
        #
        # Passing None instead of a vector means no treatment is
        # administered.
        treatment_slots = [num_people for _ in range(params.max_treatments)]
        self.action_space = multi_discrete_with_none.MultiDiscreteWithNone(
            treatment_slots)  # type: spaces.Space

        # Observable state: one health-state slot per person, plus the
        # undirected population graph.
        health_space = spaces.MultiDiscrete(
            [len(params.state_names) for _ in range(num_people)])
        graph_space = graph.GraphSpace(num_people, directed=False)
        self.observable_state_vars = {
            'health_states': health_space,
            'population_graph': graph_space,
        }  # type: Dict[Text, spaces.Space]

        # Lookup table from a state's name to its position in
        # params.state_names.
        name_to_index = {}
        for position, state_name in enumerate(params.state_names):
            name_to_index[state_name] = position
        self.state_name_to_index = name_to_index

        # The spaces above are assigned before the parent constructor runs;
        # keep this ordering.
        super(InfectiousDiseaseEnv, self).__init__(params)
        self.state = self._create_initial_state()
 def test_base_and_with_none_agree(self):
     """Both space types must give the same verdict on non-None vectors."""
     sizes = [4, 4, 4]
     plain_space = multi_discrete.MultiDiscrete(sizes)
     none_aware_space = multi_discrete_with_none.MultiDiscreteWithNone(sizes)
     # The first candidate carries a negative entry; the other two do not.
     candidates = ([-1, 1, 2], [2, 1, 0], [1, 2, 3])
     for candidate in candidates:
         plain_verdict = plain_space.contains(candidate)
         none_aware_verdict = none_aware_space.contains(candidate)
         self.assertEqual(plain_verdict, none_aware_verdict)
 def test_none_can_be_sampled(self):
     """With none_probability=1, sample() should return None."""
     always_none_space = multi_discrete_with_none.MultiDiscreteWithNone(
         nvec=[1, 2, 3, 4], none_probability=1)
     drawn = always_none_space.sample()
     self.assertIsNone(drawn)
 def test_space_contains_none(self):
     """None should be reported as a member of the space."""
     sizes = [1, 2, 3, 4]
     none_aware_space = multi_discrete_with_none.MultiDiscreteWithNone(sizes)
     has_none = none_aware_space.contains(None)
     self.assertTrue(has_none)