Esempio n. 1
0
def test_isolate_vertex(line_g):
    """Isolating vertices one at a time strips their edges from the graph.

    After vertex 0 is isolated only the two directed edges between vertices
    1 and 2 are expected to remain; once vertex 1 is isolated as well, no
    edge should be left.
    """
    graph = from_gt(line_g, None)

    isolate_vertex(graph, 0)
    remaining = set(edges(graph))
    assert remaining == {(1, 2), (2, 1)}

    isolate_vertex(graph, 1)
    remaining = set(edges(graph))
    assert not remaining
Esempio n. 2
0
def test_remove_vertex_node_index(disconnected_line_graph):
    """Isolating vertex 0 keeps all vertex ids and the expected reachability."""
    graph = disconnected_line_graph
    isolate_vertex(graph, 0)

    # all five vertex ids must still be reported after the isolation
    assert set(vertices(graph)) == set(range(5))

    # (start vertex, vertices expected to be reachable from it)
    expected_reachable = (
        (0, [0]),
        (1, [1, 2]),
        (3, [3, 4]),
    )
    for source, expected in expected_reachable:
        assert reachable_vertices(graph, source) == expected
Esempio n. 3
0
def test_isolate_vertex_num_vertices():
    """Isolating vertices drops their edges but leaves the vertex count alone."""
    _, graph, _ = input_data_gt()
    n_before = num_vertices(graph)

    isolate_vertex(graph, 0)

    # collect every endpoint that still appears in some edge
    endpoints = set()
    for edge in edges(graph):
        endpoints.update(edge)
    assert 0 not in endpoints

    assert num_vertices(graph) == n_before
    isolate_vertex(graph, 1)
    assert num_vertices(graph) == n_before
Esempio n. 4
0
def test_inf_probas_shape(g, gi, obs, with_inc_sampling):
    """Check shape and boundary values of the unconditional probabilities.

    Five unobserved nodes are, one after another, passed to
    ``observe_uninfected_node`` and isolated in ``gi``.  After each update
    the probability vector must have one entry per vertex of ``g``, be 0 for
    every node removed so far and 1.0 for every node in ``obs``.

    Might fail if the removed vertex isolates some observed nodes.
    """
    estimator = TreeBasedStatistics(g)
    pool = TreeSamplePool(
        g, 25, 'cut',
        gi=gi,
        return_type='nodes',
        with_inc_sampling=with_inc_sampling)
    pool.fill(obs)
    estimator.build_matrix(pool.samples)

    n_vertices = g.num_vertices()
    candidates = list(set(extract_nodes(g)) - set(obs))

    # remove five nodes, re-checking the probabilities after each removal
    uninfected = []
    for idx in range(5):
        node = candidates[idx]
        uninfected.append(node)

        observe_uninfected_node(g, node, obs)
        isolate_vertex(gi, node)

        # refresh the samples and the estimator with the new observation
        refreshed = pool.update_samples(obs, {node: 0})
        estimator.update_trees(refreshed, {node: 0})

        probas = estimator.unconditional_proba()

        assert probas.shape == (n_vertices,)
        for removed_node in uninfected:
            assert probas[removed_node] == 0
        for observed_node in obs:
            assert probas[observed_node] == 1.0
def test_TreeSamplePool_with_incremental_sampling(g, gi, obs, method, edge_weight):
    """Check ``TreeSamplePool.fill`` / ``update_samples`` with incremental sampling.

    The pool is filled with ``n_samples`` node-set samples, then one random
    unobserved node ``n_rm`` is isolated in ``gi`` and reported as uninfected.
    After ``update_samples``:

    * the pool still holds ``n_samples`` samples,
    * exactly the samples that contained ``n_rm`` were re-drawn,
    * no sample contains ``n_rm`` any more,
    * samples that never contained ``n_rm`` are left unmodified.
    """
    edge_weights = g.new_edge_property("float")
    # with weight 1.0 every sample is expected to include all nodes
    edge_weights.set_value(edge_weight)
    g.edge_properties['weights'] = edge_weights

    n_samples = 100
    sampler = TreeSamplePool(g, n_samples, method,
                             gi=gi,
                             return_type='nodes',
                             with_inc_sampling=True)

    sampler.fill(obs)

    assert len(sampler.samples) == n_samples

    for t in sampler.samples:
        assert isinstance(t, set)
        assert set(obs).issubset(t)
        if edge_weight == 1.0:
            # if edge weight is 1, all nodes are infected
            assert len(t) == g.num_vertices()

    # update: pick a random unobserved node and remove it
    n_rm = random.choice(
        list(set(np.arange(g.num_vertices())) - set(obs)))
    isolate_vertex(gi, n_rm)
    observe_uninfected_node(g, n_rm, obs)

    print('n_rm', n_rm)
    print('n_rm.out_edges()', list(g.vertex(n_rm).out_edges()))
    print('n_rm.in_edges()', list(g.vertex(n_rm).in_edges()))
    incident_edges = {e for e in gi_edges(gi) if n_rm in set(e)}
    print('gi.vertex(n_rm).edges()', incident_edges)

    num_invalid_trees = sum(1 for t in sampler.samples if n_rm in t)
    # these samples must not be changed by .update_samples
    valid_trees = [t
                   for t in sampler.samples
                   if n_rm not in t]
    # Snapshot each set individually.  A shallow copy of the list would share
    # the very same set objects, so an in-place modification made by
    # update_samples would show up on both sides of the comparison below and
    # the check could never fail.
    valid_trees_old = [set(t) for t in valid_trees]

    new_samples = sampler.update_samples(obs, {n_rm: 0})

    assert len(sampler.samples) == n_samples

    assert len(new_samples) == num_invalid_trees

    for t in new_samples:
        # new samples are also incremented
        assert isinstance(t, set)
        assert set(obs).issubset(t)
        if edge_weight == 1.0:
            # because of the node isolation, one node fewer is reachable
            assert len(t) == (g.num_vertices() - 1)
        else:
            assert len(t) < (g.num_vertices() - 1)

    for t in sampler.samples:
        assert n_rm not in t  # because n_rm is removed

    # make sure valid trees before and after the update remain the same
    for t1, t2 in zip(valid_trees, valid_trees_old):
        assert t1 == t2
    def run(self,
            n_queries,
            obs=None,
            c=None,
            gen_input_kwargs=None,
            iter_callback=None):
        """Issue up to ``n_queries`` queries and return the queried nodes.

        Args:
            n_queries: maximum number of queries to issue.
            obs: initially observed nodes.  If ``obs`` or ``c`` is None, both
                are taken from the first two items returned by
                ``gen_input(self.g, **gen_input_kwargs)``.
            c: per-node values indexed by node; ``c[q] == -1`` is treated as
                "not infected", ``c[q] >= 0`` as infected.
            gen_input_kwargs: keyword arguments forwarded to ``gen_input``
                when the input has to be generated (defaults to none).
            iter_callback: optional callable invoked after each fully
                processed query as
                ``iter_callback(self.g, self.q_gen, inf_nodes, uninf_nodes)``.

        Returns:
            ``(qs, aux)`` where ``qs`` is the list of query nodes in the
            order they were selected and ``aux`` is a dict with the keys
            ``'graph_changed'`` (True once any queried node turned out to be
            uninfected), ``'obs'`` and ``'c'``.
        """
        # avoid a shared mutable default argument
        if gen_input_kwargs is None:
            gen_input_kwargs = {}

        if obs is None or c is None:
            obs, c = gen_input(self.g, **gen_input_kwargs)[:2]

        self.q_gen.receive_observation(obs, c)

        aux = {'graph_changed': False, 'obs': obs, 'c': c}
        qs = []
        inf_nodes = list(obs)
        uninf_nodes = []

        if self.print_log:
            iters = tqdm(range(n_queries), total=n_queries)
        else:
            iters = range(n_queries)

        for _ in iters:
            try:
                q = self.q_gen.select_query(self.g, inf_nodes)
            except NoMoreQuery:
                if self.print_log:
                    print('no more nodes to query. queried {} nodes'.format(
                        len(qs)))
                break

            qs.append(q)

            # The last query is recorded but not processed any further:
            # no observation update and no callback happen for it.
            if len(qs) == n_queries:
                if self.print_log:
                    print('num. queries reached')
                break

            if c[q] == -1:  # not infected
                observe_uninfected_node(self.g, q, inf_nodes)
                if self.gi is not None:
                    isolate_vertex(self.gi, q)

                self.q_gen.update_pool(self.g)
                aux['graph_changed'] = True
                uninf_nodes.append(q)
            else:
                inf_nodes.append(q)

            # update tree samples if necessary
            if self.print_log:
                print('update samples started')

            label = int(c[q] >= 0)
            assert label in {0, 1}
            try:
                self.q_gen.update_observation(self.g, inf_nodes, q, label, c)
            except NoMoreQuery:
                if self.print_log:
                    print('no more queries')
                break

            if self.print_log:
                print('update samples done')

            if callable(iter_callback):
                iter_callback(self.g, self.q_gen, inf_nodes, uninf_nodes)

        return qs, aux