Esempio n. 1
0
  def test_undirected_space_not_equal_to_directed_space(self):
    """Check that directedness participates in GraphSpace equality.

    Two spaces with the same node count and edge probability, differing
    only in `directed`, must not compare equal.
    """
    node_count, edge_prob = 100, 0.05

    # Build the directed space first, then the undirected one (same order
    # as a straight-line construction would use).
    by_direction = {
        is_directed: graph.GraphSpace(
            node_count, directed=is_directed, p=edge_prob)
        for is_directed in (True, False)
    }

    self.assertNotEqual(by_direction[True], by_direction[False])
Esempio n. 2
0
    def __init__(self, params):
        """Build the action/observation spaces, then the initial state.

        Args:
          params: Environment parameters. This constructor reads
            `params.population_graph` (calling `number_of_nodes()` on it),
            `params.max_treatments` and `params.state_names`, and forwards
            `params` unchanged to the parent class constructor.
        """
        population_size = params.population_graph.number_of_nodes()

        # The action space is a vector of length max_treatments. Each element
        # represents one treatment, so at most max_treatments can be given out
        # at any one timestep. Each element takes a value in
        # [0, population_size): the index of the person who receives that
        # treatment.
        #
        # If None is passed instead of a vector, no treatment is administered.
        self.action_space = multi_discrete_with_none.MultiDiscreteWithNone(
            [population_size] * params.max_treatments
        )  # type: spaces.Space

        # Define the spaces of observable state variables: one health-state
        # index per person, plus the undirected contact network itself.
        num_health_states = len(params.state_names)
        self.observable_state_vars = {
            'health_states':
            spaces.MultiDiscrete([num_health_states] * population_size),
            'population_graph':
            graph.GraphSpace(population_size, directed=False),
        }  # type: Dict[Text, spaces.Space]

        # Map state names to their position in params.state_names.
        self.state_name_to_index = {
            state: i
            for i, state in enumerate(params.state_names)
        }

        # NOTE(review): the spaces above are assigned before the parent
        # constructor runs, presumably because the base class reads them;
        # confirm against the base class before reordering.
        super(InfectiousDiseaseEnv, self).__init__(params)
        self.state = self._create_initial_state()
Esempio n. 3
0
  def test_sampled_graphs_contain_correct_number_of_nodes(self):
    """Check that every sample from an undirected GraphSpace has the
    requested node count, over several independent draws."""
    expected_node_count = 100
    edge_probability = 0.05
    draws = 10

    undirected = graph.GraphSpace(
        expected_node_count, directed=False, p=edge_probability)

    # Sample and check one graph at a time so a failure points at the
    # first bad draw.
    for sampled_graph in (undirected.sample() for _draw in range(draws)):
      self.assertEqual(sampled_graph.number_of_nodes(), expected_node_count)
Esempio n. 4
0
    def test_graph_space_contains_contact_network(self):
        """Check that the SI model's population graph belongs to an
        undirected GraphSpace of the same size."""
        num_people = 50

        # Fully connected contact network over nodes 0..num_people-1.
        contact_network = nx.Graph()
        contact_network.add_nodes_from(range(num_people))
        contact_network.add_edges_from(complete_graph_edge_list(num_people))

        env = infectious_disease.build_si_model(
            population_graph=contact_network,
            infection_probability=1.0,
            num_treatments=0,
            max_treatments=10,
        )

        undirected_space = graph_space.GraphSpace(num_people, directed=False)
        network_in_space = undirected_space.contains(
            env.state.population_graph)
        self.assertTrue(network_in_space)