示例#1
0
class TestEdgeSplitterCommon(object):
    """Argument-validation tests for ``EdgeSplitter.train_test_split``."""

    # Built once when the class body executes and shared by every test below.
    g = create_heterogeneous_graph()

    es_obj = EdgeSplitter(g)

    def test_split_data_p_parameter(self):
        # Edge cases for p: exactly 0, negative, above 1 and exactly 1 must all
        # be rejected with ValueError.
        for invalid_p in (0, -0.1, 1.001, 1):
            with pytest.raises(ValueError):
                self.es_obj.train_test_split(p=invalid_p, method="global")

    def test_split_data_method_parameter(self):
        # p = 0.5 is a valid value; only the sampling method is wrong here
        # (the accepted methods are "global" and "local").
        with pytest.raises(ValueError):
            self.es_obj.train_test_split(p=0.5, method="other")
def heterogeneous_graph():
    """Build a seeded person/paper/venue graph together with an EdgeSplitter for it.

    Nodes:
        * 50 ``person`` nodes (ids 0-49), each with a random 0/1 ``elite`` attribute;
        * 200 ``paper`` nodes (ids 50-249);
        * 10 ``venue`` nodes (ids 250-259).

    Edges:
        * ``friend`` (person-person), each carrying a ``date`` string in
          ``dd/mm/yyyy`` format between 01/01/2015 (inclusive) and
          01/01/2017 (exclusive);
        * ``writes`` (person-paper);
        * ``published-at`` (paper-venue), exactly one venue per paper.

    Note: this calls ``random.seed(152)`` on the global ``random`` module so
    the same graph is produced on every call; as a consequence it also resets
    the global random state seen by any code that runs afterwards.

    Returns:
        tuple: ``(g, EdgeSplitter(g))`` where ``g`` is an undirected ``nx.Graph``.
    """
    # TODO: We test if this graph is connected but there is no guarantee of
    # connectivity in this code.
    g = nx.Graph()

    # Produces the same graph every time. The order of the random calls below
    # must not change, otherwise the generated graph changes too.
    random.seed(152)

    start_date_dt = datetime.strptime("01/01/2015", "%d/%m/%Y")
    end_date_dt = datetime.strptime("01/01/2017", "%d/%m/%Y")
    # the number of days between start and end dates
    start_end_days = (end_date_dt - start_date_dt).days

    # 50 nodes of type person
    person_node_ids = list(range(50))
    for person in person_node_ids:
        g.add_node(person, label="person", elite=random.choice([0, 1]))

    # 200 nodes of type paper
    paper_node_ids = list(range(50, 250))
    g.add_nodes_from(paper_node_ids, label="paper")

    # 10 nodes of type venue
    venue_node_ids = list(range(250, 260))
    g.add_nodes_from(venue_node_ids, label="venue")

    # add the person - friend -> person edges
    # Each person draws 0 to 4 others (randrange(5) excludes 5), minus itself
    # if it was drawn. The graph is undirected, so a person can end up with
    # more friends than it drew, because others may also draw them. Each edge
    # gets a random date in [start_date, end_date).
    for person_id in person_node_ids:
        k = random.randrange(5)
        friend_ids = set(random.sample(person_node_ids, k=k)) - {person_id}  # no self loops
        for friend in friend_ids:
            g.add_edge(
                person_id,
                friend,
                label="friend",
                date=(
                    start_date_dt + timedelta(days=random.randrange(start_end_days))
                ).strftime("%d/%m/%Y"),
            )

    # add the person - writes -> paper edges (0 to 4 distinct papers per person)
    for person_id in person_node_ids:
        k = random.randrange(5)
        paper_ids = random.sample(paper_node_ids, k=k)
        for paper in paper_ids:
            g.add_edge(person_id, paper, label="writes")

    # add the paper - published-at -> venue edges
    for paper_id in paper_node_ids:
        # paper is published at 1 venue only; sample(..., k=1)[0] is kept
        # (rather than random.choice) so the random stream, and thus the graph,
        # stays the same.
        venue_id = random.sample(venue_node_ids, k=1)[0]
        g.add_edge(paper_id, venue_id, label="published-at")

    return g, EdgeSplitter(g)
    def test_stellargraph(self):
        """EdgeSplitter works on StellarGraph inputs, including as ``g_master``."""
        original_graph = example_graph_random(n_nodes=20, n_edges=50)

        # Ground-truth edges in both orientations, so a sampled (src, dst) pair
        # can be looked up regardless of its direction.
        true_edges = {
            pair
            for src, tgt in original_graph.edges()
            for pair in ((src, tgt), (tgt, src))
        }

        def check(split_graph, ids, labels):
            assert isinstance(split_graph, StellarGraph)

            # Node features must be carried over from the original graph.
            nodes = split_graph.nodes()
            np.testing.assert_array_equal(
                split_graph.node_features(nodes), original_graph.node_features(nodes),
            )

            # A truthy label marks a real edge; a falsy label marks a non-edge.
            for (src, dst), label in zip(ids, labels):
                assert ((src, dst) in true_edges) == bool(label)

        # Split the original graph directly.
        split1, ids1, labels1 = EdgeSplitter(original_graph).train_test_split(
            p=0.5, method="global"
        )
        check(split1, ids1, labels1)

        # Split the already-reduced graph, passing a StellarGraph as g_master.
        split2, ids2, labels2 = EdgeSplitter(split1, original_graph).train_test_split(
            p=0.5, method="global"
        )
        check(split2, ids2, labels2)
def cora():
    """Load the Cora EPGM dataset as an undirected ``nx.Graph`` plus an EdgeSplitter.

    The dataset path is resolved relative to the current working directory,
    so this works both when the tests run from the repository root and when
    they run from inside the ``tests`` directory.

    Returns:
        tuple: ``(g, EdgeSplitter(g))``.
    """
    # Compare the last path component portably: splitting os.getcwd() on "/"
    # never matches on Windows, where the separator is a backslash.
    if os.path.basename(os.getcwd()) == "tests":
        input_dir = os.path.expanduser("resources/data/cora/cora.epgm")
    else:
        input_dir = os.path.expanduser("tests/resources/data/cora/cora.epgm")

    dataset_name = "cora"

    # Convert whatever read_graph returns into a plain undirected nx.Graph.
    g = nx.Graph(read_graph(input_dir, dataset_name))

    es_obj = EdgeSplitter(g)

    return g, es_obj
class TestEdgeSplitterHeterogeneous(object):
    """Tests for EdgeSplitter on the heterogeneous test graph."""

    # Built once when the class body executes and shared by every test below.
    g = create_heterogeneous_graph()

    es_obj = EdgeSplitter(g)

    def _assert_valid_split(
        self, g_test, edge_data_ids_test, edge_data_labels_test, require_samples=True
    ):
        """Check the invariants of a successful ``train_test_split`` result.

        Args:
            g_test: the reduced graph returned by ``train_test_split``.
            edge_data_ids_test: the sampled edge ids.
            edge_data_labels_test: array of sampled labels (1 = positive, 0 = negative).
            require_samples: if True, also require at least one positive and
                at least one negative sample.
        """
        num_sampled_positives = np.sum(edge_data_labels_test == 1)
        num_sampled_negatives = np.sum(edge_data_labels_test == 0)

        if require_samples:
            assert num_sampled_positives > 0
            assert num_sampled_negatives > 0
        assert len(edge_data_ids_test) == len(edge_data_labels_test)
        # positives and negatives are sampled in equal numbers
        assert (num_sampled_positives - num_sampled_negatives) == 0
        # the returned graph must have fewer edges than the original
        assert len(g_test.edges()) < len(self.g.edges())
        assert nx.is_connected(g_test)

    def test_split_data_by_edge_type_and_attribute(self):
        # test global method for negative edge sampling
        self._test_split_data_by_edge_type_and_attribute(method="global")

        # test local method for negative edge sampling
        self._test_split_data_by_edge_type_and_attribute(method="local")

    def _test_split_data_by_edge_type_and_attribute(self, method):
        def split(p, **overrides):
            """Split 'friend' edges by 'date' < 01/01/2008, with optional argument overrides."""
            kwargs = dict(
                method=method,
                keep_connected=True,
                edge_label="friend",
                edge_attribute_label="date",
                attribute_is_datetime=True,
                edge_attribute_threshold="01/01/2008",
            )
            kwargs.update(overrides)
            return self.es_obj.train_test_split(p=p, **kwargs)

        p = 0.1
        g_test, edge_data_ids_test, edge_data_labels_test = split(p)
        self._assert_valid_split(g_test, edge_data_ids_test, edge_data_labels_test)

        p = 0.8
        with pytest.raises(ValueError):
            # This will raise ValueError because it cannot sample enough positive
            # edges while maintaining graph connectivity.
            split(p)

        with pytest.raises(ValueError):
            # This will raise ValueError because it cannot sample enough negative
            # edges of the given edge_label.
            split(p, keep_connected=False)

        p = 0.1
        with pytest.raises(KeyError):
            # Edges of type 'friend' don't have an attribute named 'Any'.
            split(p, edge_attribute_label="Any")

        with pytest.raises(KeyError):
            # Edges of type 'published-at' don't have a 'date' attribute.
            split(p, edge_label="published-at")

        with pytest.raises(ValueError):
            # The edge attribute must be specified as datetime.
            split(p, attribute_is_datetime=False)

        # None of these thresholds is a valid dd/mm/yyyy date: "01/2008" is
        # missing the day, "Jan 2005" and "01-01-2000" use the wrong format,
        # "01/14/2008" has month 14 and "32/10/2008" has day 32.
        for bad_threshold in ("01/2008", "Jan 2005", "01-01-2000", "01/14/2008", "32/10/2008"):
            with pytest.raises(ValueError):
                split(p, edge_attribute_threshold=bad_threshold)

        with pytest.raises(Exception):
            # This call will raise an exception because all the edges of type
            # 'writes' are on the minimum spanning tree and cannot be removed.
            split(p, edge_label="writes")

    def test_split_data_by_edge_type(self):
        # test global method for negative edge sampling
        self._test_split_data_by_edge_type(method="global")

        # test local method for negative edge sampling
        self._test_split_data_by_edge_type(method="local")

    def _test_split_data_by_edge_type(self, method):
        p = 0.1
        g_test, edge_data_ids_test, edge_data_labels_test = self.es_obj.train_test_split(
            p=p, method=method, edge_label="friend", keep_connected=True
        )
        self._assert_valid_split(
            g_test, edge_data_ids_test, edge_data_labels_test, require_samples=False
        )

        with pytest.raises(Exception):
            # The graph has no edges of type 'No Label'.
            self.es_obj.train_test_split(
                p=p, method=method, keep_connected=True, edge_label="No Label"
            )

        # p = 0.8 is expected to be rejected for the 'friend' edges both with
        # and without keep_connected.
        p = 0.8
        for keep_connected in (True, False):
            with pytest.raises(ValueError):
                self.es_obj.train_test_split(
                    p=p, method=method, edge_label="friend", keep_connected=keep_connected
                )

    def test_split_data_global(self):
        g_test, edge_data_ids_test, edge_data_labels_test = self.es_obj.train_test_split(
            p=0.1, method="global", keep_connected=True
        )
        self._assert_valid_split(g_test, edge_data_ids_test, edge_data_labels_test)

    def test_split_data_local(self):
        # using default sampling probabilities
        g_test, edge_data_ids_test, edge_data_labels_test = self.es_obj.train_test_split(
            p=0.1, method="local", keep_connected=True
        )
        self._assert_valid_split(g_test, edge_data_ids_test, edge_data_labels_test)
class TestEdgeSplitterHomogeneous(object):
    """Tests for EdgeSplitter on the Cora dataset loaded as an ``nx.Graph``."""

    # Resolve the dataset path relative to the working directory, so the tests
    # run both from the repository root and from inside ``tests``. Compare the
    # last path component portably: splitting on "/" never matches on Windows.
    if os.path.basename(os.getcwd()) == "tests":
        input_dir = os.path.expanduser("resources/data/cora/cora.epgm")
    else:
        input_dir = os.path.expanduser("tests/resources/data/cora/cora.epgm")

    dataset_name = "cora"

    # Loaded once when the class body executes and shared by every test below.
    g = read_graph(input_dir, dataset_name)
    g = nx.Graph(g)

    es_obj = EdgeSplitter(g)

    def _assert_valid_split(self, g_test, edge_data_ids_test, edge_data_labels_test):
        """Check the invariants of a successful ``train_test_split`` result.

        Args:
            g_test: the reduced graph returned by ``train_test_split``.
            edge_data_ids_test: the sampled edge ids.
            edge_data_labels_test: array of sampled labels (1 = positive, 0 = negative).
        """
        num_sampled_positives = np.sum(edge_data_labels_test == 1)
        num_sampled_negatives = np.sum(edge_data_labels_test == 0)

        assert num_sampled_positives > 0
        assert num_sampled_negatives > 0
        assert len(edge_data_ids_test) == len(edge_data_labels_test)
        # positives and negatives are sampled in equal numbers
        assert (num_sampled_positives - num_sampled_negatives) == 0
        # the returned graph must have fewer edges than the original
        assert len(g_test.edges()) < len(self.g.edges())
        assert nx.is_connected(g_test)

    def test_split_data_global(self):
        p = 0.1

        g_test, edge_data_ids_test, edge_data_labels_test = self.es_obj.train_test_split(
            p=p, method="global", keep_connected=True
        )
        self._assert_valid_split(g_test, edge_data_ids_test, edge_data_labels_test)

        with pytest.raises(ValueError):
            # This should raise ValueError because it asks for more positive
            # samples than are available without breaking graph connectivity.
            self.es_obj.train_test_split(p=0.8, method="global", keep_connected=True)

    def test_split_data_local(self):
        p = 0.1

        # using default sampling probabilities
        g_test, edge_data_ids_test, edge_data_labels_test = self.es_obj.train_test_split(
            p=p, method="local", keep_connected=True
        )
        self._assert_valid_split(g_test, edge_data_ids_test, edge_data_labels_test)

        # using explicit sampling probabilities for the local method
        sampling_probs = [0.0, 0.0, 0.1, 0.2, 0.5, 0.2]
        g_test, edge_data_ids_test, edge_data_labels_test = self.es_obj.train_test_split(
            p=p, method="local", probs=sampling_probs, keep_connected=True
        )
        self._assert_valid_split(g_test, edge_data_ids_test, edge_data_labels_test)

        with pytest.raises(ValueError):
            # This should raise ValueError because it asks for more positive
            # samples than are available without breaking graph connectivity.
            self.es_obj.train_test_split(
                p=0.8, method="local", probs=sampling_probs, keep_connected=True
            )

        sampling_probs = [0.2, 0.1, 0.2, 0.5, 0.2]  # sums to 1.2, not 1
        with pytest.raises(ValueError):
            self.es_obj.train_test_split(p=p, method="local", probs=sampling_probs)