Example #1
def StratifiedKFold(db: graph_database.Database, num_splits: int):
    """Apply a stratified K-fold split on the graph database."""
    with db.Session() as session:
        num_graphs = session.query(sql.func.count(
            graph_database.GraphMeta.id)).one()[0]
        with prof.Profile(f"Loaded labels from {num_graphs} graphs"):
            # Load all graphs as a single batch. WARNING this will not work for large
            # datasets!
            batcher = graph_batcher.GraphBatcher(db)
            options = graph_batcher.GraphBatchOptions(
                max_graphs=num_graphs + 1)
            graph_batches = list(batcher.MakeGraphBatchIterator(options))
            assert len(graph_batches) == 1
            graph_batch = graph_batches[0]

            graph_ids = np.array(
                graph_batch.log._transient_data["graph_indices"],
                dtype=np.int32)
            # Compute the dense labels from one-hot vectors.
            labels = np.argmax(graph_batch.graph_y, axis=1)

    # Split the graphs
    seed = 0xCEC
    splitter = model_selection.StratifiedKFold(n_splits=num_splits,
                                               shuffle=True,
                                               random_state=seed)
    dataset_splits = splitter.split(graph_ids, labels)

    groups = {
        str(i): graph_ids[test]
        for i, (train, test) in enumerate(dataset_splits)
    }
    return groups
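
For reference, the split logic above can be reproduced without the project's graph database. The following sketch uses only numpy and scikit-learn; the graph IDs and labels are made-up stand-ins for the values loaded from the graph batch above, not the project's data.

import numpy as np
from sklearn import model_selection

# Stand-in data: 20 fake graph IDs with alternating class labels
# (the dense labels the example derives via np.argmax on one-hot rows).
graph_ids = np.arange(20, dtype=np.int32)
labels = np.array([0, 1] * 10)

splitter = model_selection.StratifiedKFold(n_splits=5,
                                           shuffle=True,
                                           random_state=0xCEC)
groups = {
    str(i): graph_ids[test]
    for i, (train, test) in enumerate(splitter.split(graph_ids, labels))
}
# Each key is a fold index; each value holds the graph IDs held out for that fold.
print({fold: ids.tolist() for fold, ids in groups.items()})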
Example #2
def test_GraphBatcher_collect_all_inputs(graph_count: int):
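    """Test that all input graphs are collected into a single disjoint batch."""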
    batcher = graph_batcher.GraphBatcher(
        MockIterator([
            random_graph_tuple_generator.CreateRandomGraphTuple()
            for _ in range(graph_count)
        ]))
    batches = list(batcher)
    assert len(batches) == 1
    assert batches[0].is_disjoint_graph
    assert batches[0].disjoint_graph_count == graph_count
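
These tests construct their inputs through a MockIterator helper that is not shown in the examples. A minimal sketch of such a helper, assuming it only needs to expose the standard iterator protocol over a list of graph tuples (the real test fixture may do more):

class MockIterator:
    """Wrap a list of graph tuples in a plain iterator."""

    def __init__(self, items):
        self._items = iter(items)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._items)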
Example #3
def test_GraphBatcher_max_node_count_limit_handler_include():
    """Test that graph is included when larger than max node count."""
    big_graph = random_graph_tuple_generator.CreateRandomGraphTuple(
        node_count=10)

    batcher = graph_batcher.GraphBatcher(
        MockIterator([big_graph]),
        max_node_count=5,
        max_node_count_limit_handler="include",
    )

    assert next(batcher)
Example #4
def test_GraphBatcher_max_node_count_limit_handler_error():
    """Test that error is raised when graph is larger than max node count."""
    big_graph = random_graph_tuple_generator.CreateRandomGraphTuple(
        node_count=10)

    batcher = graph_batcher.GraphBatcher(
        MockIterator([big_graph]),
        max_node_count=5,
        max_node_count_limit_handler="error",
    )

    with test.Raises(ValueError):
        next(batcher)
Example #5
def test_fuzz_GraphBatcher(graph_count: int, max_graph_count: int,
                           max_node_count: int):
    """Fuzz the graph batcher with a range of parameter choices and input
  sizes.
  """
    graphs = MockIterator([
        random_graph_tuple_generator.CreateRandomGraphTuple()
        for _ in range(graph_count)
    ])
    batcher = graph_batcher.GraphBatcher(graphs,
                                         max_node_count=max_node_count,
                                         max_graph_count=max_graph_count)
    batches = list(batcher)
    assert sum(b.disjoint_graph_count for b in batches) == graph_count
Example #6
def test_GraphBatcher_exact_graph_count():
    """Test the number of batches when exact graph counts are required."""
    batcher = graph_batcher.GraphBatcher(
        MockIterator([
            random_graph_tuple_generator.CreateRandomGraphTuple()
            for _ in range(7)
        ]),
        exact_graph_count=3,
    )

    batches = list(batcher)
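    # 7 input graphs with exact_graph_count=3 yield two batches of exactly three
    # graphs each; the single leftover graph does not form a batch.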
    assert len(batches) == 2
    assert batches[0].disjoint_graph_count == 3
    assert batches[1].disjoint_graph_count == 3
Example #7
def Main():
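    """Build program graphs from the IRs in LLVM_IR, encode and label them,
    pack them into graph tuples, and print summary statistics."""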
    irs = [fs.Read(path) for path in LLVM_IR.iterdir()]
    ir_count = len(irs)

    with prof.ProfileToStdout(lambda t: (
            f"STAGE 1: Construct unlabelled graphs (llvm2graph)         "
            f"({humanize.Duration(t / ir_count)} / IR)")):
        graphs = [llvm2graph.BuildProgramGraphNetworkX(ir) for ir in irs]

    encoder = node_encoder.GraphNodeEncoder()
    with prof.ProfileToStdout(lambda t: (
            f"STAGE 2: Encode graphs (inst2vec)                         "
            f"({humanize.Duration(t / ir_count)} / IR)")):
        for graph, ir in zip(graphs, irs):
            encoder.EncodeNodes(graph, ir)

    features_count = 0
    features_lists = []
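    # The message lambdas passed to prof.ProfileToStdout are formatted when each
    # profiled block exits, so features_count below has its accumulated total by
    # the time the STAGE 3 and STAGE 4 messages are printed.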
    with prof.ProfileToStdout(lambda t: (
            f"STAGE 3: Produce labelled graphs (reachability analysis)  "
            f"({humanize.Duration(t / features_count)} / graph)")):
        for graph in graphs:
            analysis = reachability.ReachabilityAnnotator(
                programl.NetworkXToProgramGraph(graph))
            features_list = analysis.MakeAnnotated(n=10).graphs
            features_count += len(features_list)
            features_lists.append(features_list)

    def iter():
        for features_list in features_lists:
            for graph in features_list:
                yield graph_tuple.GraphTuple.CreateFromNetworkX(graph)

    with prof.ProfileToStdout(lambda t: (
            f"STAGE 4: Construct graph tuples                           "
            f"({humanize.Duration(t / features_count)} / graph)")):
        batcher = graph_batcher.GraphBatcher(iter(), max_node_count=10000)
        graph_tuples = list(batcher)

    print("=================================")
    print(f"Unlabelled graphs count: {ir_count}")
    print(f"  Labelled graphs count: {features_count}")
    print(f"     Graph tuples count: {len(graph_tuples)}")
    print(
        f"       Total node count: {sum(gt.node_count for gt in graph_tuples)}"
    )
    print(
        f"       Total edge count: {sum(gt.edge_count for gt in graph_tuples)}"
    )
Example #8
def test_GraphBatcher_max_node_count_limit_handler_skip():
    """Test that graph is skipped when larger than max node count."""
    big_graph = random_graph_tuple_generator.CreateRandomGraphTuple(
        node_count=10)

    batcher = graph_batcher.GraphBatcher(
        MockIterator([big_graph]),
        max_node_count=5,
        max_node_count_limit_handler="skip",
    )

    # The only graph is skipped, so the batcher is exhausted without yielding.
    with test.Raises(StopIteration):
        next(batcher)
Example #9
def test_GraphBatcher_divisible_node_count():
    """Test the number of batches returned with evenly divisible node counts."""
    batcher = graph_batcher.GraphBatcher(
        MockIterator([
            random_graph_tuple_generator.CreateRandomGraphTuple(node_count=5),
            random_graph_tuple_generator.CreateRandomGraphTuple(node_count=5),
            random_graph_tuple_generator.CreateRandomGraphTuple(node_count=5),
            random_graph_tuple_generator.CreateRandomGraphTuple(node_count=5),
        ]),
        max_node_count=10,
    )

    batches = list(batcher)
    assert len(batches) == 2
    assert batches[0].is_disjoint_graph
    assert batches[0].disjoint_graph_count == 2
    assert batches[1].is_disjoint_graph
    assert batches[1].disjoint_graph_count == 2
Example #10
def test_GraphBatcher_empty_graphs_list():
    """Test input with empty graph """
    batcher = graph_batcher.GraphBatcher(MockIterator([]))
    with test.Raises(StopIteration):
        next(batcher)