def test_undirected_space_not_equal_to_directed_space(self):
  """A directed GraphSpace must not compare equal to an undirected one."""
  node_count, edge_prob = 100, 0.05
  # Build the directed space first, then the undirected one (same order as
  # before), keyed by their `directed` flag.
  spaces_by_direction = {
      is_directed: graph.GraphSpace(node_count, directed=is_directed,
                                    p=edge_prob)
      for is_directed in (True, False)
  }
  self.assertNotEqual(spaces_by_direction[True], spaces_by_direction[False])
def __init__(self, params):
  """Defines the action/observation spaces, then builds the initial state.

  Args:
    params: Environment parameters. This method reads
      `params.population_graph` (via `number_of_nodes()`),
      `params.max_treatments` and `params.state_names`.
  """
  population_size = params.population_graph.number_of_nodes()

  # The action space is a vector of length max_treatments: each element is one
  # treatment (so at most max_treatments can be given out in one timestep) and
  # takes a value in [0, population_size), namely the index of the person who
  # receives that treatment.
  #
  # If None is passed instead of a vector, no treatment is administered.
  self.action_space = multi_discrete_with_none.MultiDiscreteWithNone(
      [population_size] * params.max_treatments)  # type: spaces.Space

  # Define the spaces of observable state variables.
  self.observable_state_vars = {
      # One entry per person, each an index into params.state_names.
      'health_states': spaces.MultiDiscrete(
          [len(params.state_names)] * population_size),
      'population_graph': graph.GraphSpace(population_size, directed=False),
  }  # type: Dict[Text, spaces.Space]

  # Map state names to indices.
  self.state_name_to_index = {
      state: i for i, state in enumerate(params.state_names)
  }

  # The spaces above are assigned before the base-class initializer runs.
  # NOTE(review): presumably the base class reads them; confirm before
  # reordering.
  super(InfectiousDiseaseEnv, self).__init__(params)
  self.state = self._create_initial_state()
def test_sampled_graphs_contain_correct_number_of_nodes(self):
  """Every graph sampled from an undirected GraphSpace has the expected size."""
  trials, expected_nodes, edge_prob = 10, 100, 0.05
  undirected = graph.GraphSpace(expected_nodes, directed=False, p=edge_prob)
  remaining = trials
  # Draw one sample per iteration and check it immediately.
  while remaining > 0:
    sampled = undirected.sample()
    self.assertEqual(sampled.number_of_nodes(), expected_nodes)
    remaining -= 1
def test_graph_space_contains_contact_network(self):
  """An SI model's population graph is contained in a matching GraphSpace.

  Builds an undirected contact network over 50 people with the edges from
  `complete_graph_edge_list`, constructs an SI model on it, and checks that an
  undirected GraphSpace of the same size contains the env's population graph.
  """
  population_size = 50
  contact_network = nx.Graph()
  contact_network.add_nodes_from(range(population_size))
  contact_network.add_edges_from(complete_graph_edge_list(population_size))

  env = infectious_disease.build_si_model(
      population_graph=contact_network,
      infection_probability=1.0,
      num_treatments=0,
      max_treatments=10,
  )

  expected_space = graph_space.GraphSpace(population_size, directed=False)
  self.assertTrue(expected_space.contains(env.state.population_graph))