def test_CreateEmpty(db_session: graph_tuple_database.Database.SessionType):
    """An empty graph tuple keeps the IR id it was built with and can be committed."""
    empty_tuple = graph_tuple_database.GraphTuple.CreateEmpty(ir_id=1)
    assert empty_tuple.ir_id == 1

    # Committing exercises the database insert path for an empty tuple.
    db_session.add(empty_tuple)
    db_session.commit()
def test_CreateFromGraphTuple_node_x_dimensionality(
    db_session: graph_tuple_database.Database.SessionType,
):
    """A random graph tuple converted to a DB row reports node_x dimensionality 1."""
    random_tuple = random_graph_tuple_generator.CreateRandomGraphTuple()
    db_row = graph_tuple_database.GraphTuple.CreateFromGraphTuple(
        random_tuple, ir_id=1
    )
    assert db_row.node_x_dimensionality == 1

    # Persist the converted row so the insert path is exercised as well.
    db_session.add(db_row)
    db_session.commit()
def test_CreateFromGraphTuple_graph_y_dimensionality(
    db_session: graph_tuple_database.Database.SessionType,
):
    """The graph label dimensionality survives conversion, for 0 and 2 labels."""
    converted_rows = []
    # Generate, convert and check each dimensionality in turn, then commit
    # all resulting rows together in a single transaction.
    for y_dim in (0, 2):
        random_tuple = random_graph_tuple_generator.CreateRandomGraphTuple(
            graph_y_dimensionality=y_dim
        )
        db_row = graph_tuple_database.GraphTuple.CreateFromGraphTuple(
            random_tuple, ir_id=1
        )
        assert db_row.graph_y_dimensionality == y_dim
        converted_rows.append(db_row)

    db_session.add_all(converted_rows)
    db_session.commit()
def test_fuzz_GraphTuple_CreateFromNetworkX(
    db_session: graph_tuple_database.Database.SessionType,
):
    """Fuzz networkx -> graph tuple conversion with a randomly generated graph.

    Checks the derived counts on the resulting row against the source graph,
    then commits the row to the database.
    """
    nx_graph = random_networkx_generator.CreateRandomGraph()
    random_ir_id = random.randint(0, int(4e6))
    row = graph_tuple_database.GraphTuple.CreateFromNetworkX(
        g=nx_graph, ir_id=random_ir_id
    )

    # The total edge count is the sum of the control, data and call edge counts.
    flow_edge_sum = row.control_edge_count + row.data_edge_count + row.call_edge_count
    assert row.edge_count == flow_edge_sum
    # 40 characters; presumably a hex-encoded SHA-1 digest.
    assert len(row.sha1) == 40

    expected_nodes = nx_graph.number_of_nodes()
    expected_edges = nx_graph.number_of_edges()
    # Counts stored on the row and on the embedded tuple must match the source graph.
    assert row.node_count == expected_nodes
    assert row.edge_count == expected_edges
    assert row.tuple.node_count == expected_nodes
    assert row.tuple.edge_count == expected_edges
    # Presumably one entry per edge flow type (control / data / call).
    assert len(row.tuple.adjacencies) == 3
    assert len(row.tuple.edge_positions) == 3

    # Committing the row surfaces any SQL integrity errors.
    db_session.add(row)
    db_session.commit()