예제 #1
0
    def test_node_attributes(self):
        """Exercise .node_attributes() with a known, an unknown and a missing node type."""
        epgm = EPGM(self.input_dir)
        first_graph = epgm.G["graphs"][0]
        gid = first_graph["id"]

        # dataset has 1 unique 'user' node attribute, 'elite'
        attrs = epgm.node_attributes(gid, self.node_type)
        assert self.target_attribute in attrs

        n_attrs = len(attrs)
        assert n_attrs == 1, "There should be 1 unique node attribute; found {}".format(
            n_attrs
        )

        # A node type that does not exist in the graph must yield no attributes.
        unknown_type_attrs = epgm.node_attributes(gid, "business")
        assert len(unknown_type_attrs) == 0

        # node_type is mandatory: leaving it out must raise a TypeError.
        with pytest.raises(TypeError):
            epgm.node_attributes(gid)
예제 #2
0
def load_data(path, dataset_name=None, node_type=None, target_attribute=None):
    """
    Loads the node data

    If ``path`` is a directory it is read as a graph in EPGM format; otherwise it is
    read as a space-delimited, header-less text file whose rows are sorted by the
    first column and returned as strings.

     :param path: Input filename or directory where graph in EPGM format is stored
     :param dataset_name: For EPGM format, the ``meta.label`` of the graph to load
     :param node_type: For HINs, the node type to consider. If None, the graph must
        contain exactly one node type, which is then used.
     :param target_attribute: For EPGM format, the target node attribute. If None, nodes
        of ``node_type`` must have exactly one attribute, which is then used.
     :return: N x 2 numpy arrays where the first column is the node id and the second column is the node label.
     :raises ValueError: if no graph in the EPGM directory is labelled ``dataset_name``
     :raises Exception: if ``node_type`` or ``target_attribute`` is omitted and cannot
        be inferred unambiguously from the graph
    """
    if not os.path.isdir(path):
        # Plain-text branch: everything is read as str, so ids and labels are strings.
        y_df = pd.read_csv(path, delimiter=" ", header=None, dtype=str)
        y_df.sort_values(by=[0], inplace=True)
        return y_df.values

    g_epgm = EPGM(path)
    graphs = g_epgm.G["graphs"]

    matching_ids = [g["id"] for g in graphs if g["meta"]["label"] == dataset_name]
    if not matching_ids:
        # Previously this case fell through and crashed later with an
        # UnboundLocalError on g_id; fail early with an actionable message.
        raise ValueError(
            "No graph labelled {!r} found in {}; available labels: {}.".format(
                dataset_name, path, [g["meta"]["label"] for g in graphs]
            )
        )
    # If several graphs share the label, the last one wins (unchanged behaviour).
    g_id = matching_ids[-1]

    g_vertices = g_epgm.G["vertices"]  # retrieve all graph vertices

    if node_type is None:
        node_types = g_epgm.node_types(g_id)
        if len(node_types) == 1:
            node_type = node_types[0]
        elif len(node_types) == 0:
            raise Exception("No node types detected in graph {}.".format(g_id))
        else:
            raise Exception(
                "Multiple node types detected in graph {}: {}.".format(
                    g_id, node_types
                )
            )

    if target_attribute is None:
        attributes = g_epgm.node_attributes(g_id, node_type)
        if len(attributes) == 1:
            target_attribute = attributes[0]
        elif len(attributes) == 0:
            raise Exception(
                "No node attributes detected for nodes of type {} in graph {}.".format(
                    node_type, g_id
                )
            )
        else:
            raise Exception(
                "Multiple node attributes detected for nodes of type {} in graph {}: {}.".format(
                    node_type, g_id, attributes
                )
            )

    return np.array(
        get_nodes(g_vertices, node_type=node_type, target_attribute=target_attribute)
    )