def __init__(self, params):
    """Set up the action space, observable state spaces and initial state.

    Args:
      params: Environment parameters. This initializer reads
        `params.population_graph` (for its node count),
        `params.max_treatments` and `params.state_names`, and forwards
        `params` to the base-class initializer.
    """
    num_people = params.population_graph.number_of_nodes()
    num_health_states = len(params.state_names)

    # An action is a vector with one entry per available treatment
    # (`max_treatments` entries in total). Each entry ranges over
    # [0, num_people) and is the index of the person receiving that
    # treatment.
    #
    # Passing None instead of a vector means no treatment is administered.
    treatment_nvec = [num_people for _ in range(params.max_treatments)]
    self.action_space = multi_discrete_with_none.MultiDiscreteWithNone(
        treatment_nvec)  # type: spaces.Space

    # Spaces of the observable state variables: one health-state slot per
    # person, plus the undirected population graph.
    health_nvec = [num_health_states for _ in range(num_people)]
    self.observable_state_vars = {
        'health_states': spaces.MultiDiscrete(health_nvec),
        'population_graph': graph.GraphSpace(num_people, directed=False),
    }  # type: Dict[Text, spaces.Space]

    # Lookup table from each state name to its position in
    # `params.state_names`.
    name_to_index = {}
    for position, state_name in enumerate(params.state_names):
        name_to_index[state_name] = position
    self.state_name_to_index = name_to_index

    # The base initializer runs only after the attributes above are set;
    # the initial state is built last.
    super(InfectiousDiseaseEnv, self).__init__(params)
    self.state = self._create_initial_state()
def test_base_and_with_none_agree(self):
    """Both space types give the same `contains` answer for real vectors.

    Builds a MultiDiscrete and a MultiDiscreteWithNone from the same `nvec`
    and checks that `contains` returns equal results for several non-None
    candidate vectors, including one with a negative entry.
    """
    nvec = [4, 4, 4]
    base_space = multi_discrete.MultiDiscrete(nvec)
    with_none_space = multi_discrete_with_none.MultiDiscreteWithNone(nvec)

    candidates = (
        [-1, 1, 2],
        [2, 1, 0],
        [1, 2, 3],
    )
    for candidate in candidates:
        expected = base_space.contains(candidate)
        actual = with_none_space.contains(candidate)
        self.assertEqual(expected, actual)
def test_none_can_be_sampled(self):
    """Sampling with `none_probability=1` should yield None."""
    always_none_space = multi_discrete_with_none.MultiDiscreteWithNone(
        nvec=[1, 2, 3, 4], none_probability=1)
    sampled = always_none_space.sample()
    self.assertIsNone(sampled)
def test_space_contains_none(self):
    """`contains(None)` should be truthy for a MultiDiscreteWithNone."""
    nvec = [1, 2, 3, 4]
    space = multi_discrete_with_none.MultiDiscreteWithNone(nvec)
    none_is_member = space.contains(None)
    self.assertTrue(none_is_member)