def StratifiedKFold(db: graph_database.Database, num_splits: int):
  """Apply a stratified K-fold split on the graph database.

  Args:
    db: The graph database to split.
    num_splits: The number of folds to produce.

  Returns:
    A dict mapping stringified fold index to an array of graph IDs in that
    fold's test set.
  """
  # Count the graphs up front so the entire database can be read as one batch.
  with db.Session() as session:
    num_graphs = session.query(
        sql.func.count(graph_database.GraphMeta.id)).one()[0]

  with prof.Profile(f"Loaded labels from {num_graphs} graphs"):
    # Load all graphs as a single batch. WARNING this will not work for large
    # datasets!
    batcher = graph_batcher.GraphBatcher(db)
    options = graph_batcher.GraphBatchOptions(max_graphs=num_graphs + 1)
    batches = list(batcher.MakeGraphBatchIterator(options))
    assert len(batches) == 1
    batch = batches[0]
    graph_ids = np.array(
        batch.log._transient_data["graph_indices"], dtype=np.int32)
    # Recover dense class labels from the one-hot graph_y matrix.
    labels = np.argmax(batch.graph_y, axis=1)

  # Produce shuffled, label-stratified folds with a fixed seed for
  # reproducibility.
  seed = 0xCEC
  splitter = model_selection.StratifiedKFold(
      n_splits=num_splits, shuffle=True, random_state=seed)
  return {
      str(i): graph_ids[test_index]
      for i, (_, test_index) in enumerate(splitter.split(graph_ids, labels))
  }
def test_GraphBatcher_collect_all_inputs(graph_count: int):
  """With no limits set, every input graph lands in a single batch."""
  graphs = [
      random_graph_tuple_generator.CreateRandomGraphTuple()
      for _ in range(graph_count)
  ]
  batches = list(graph_batcher.GraphBatcher(MockIterator(graphs)))
  assert len(batches) == 1
  batch = batches[0]
  assert batch.is_disjoint_graph
  assert batch.disjoint_graph_count == graph_count
def test_GraphBatcher_max_node_count_limit_handler_include():
  """Test that graph is included when larger than max node count."""
  # Renamed from ..._skip: the old name both contradicted the "include"
  # handler under test and collided with the real "skip" handler test, so
  # pytest silently collected only one of the two.
  big_graph = random_graph_tuple_generator.CreateRandomGraphTuple(
      node_count=10)
  batcher = graph_batcher.GraphBatcher(
      MockIterator([big_graph]),
      max_node_count=5,
      max_node_count_limit_handler="include",
  )
  # The oversize graph is still yielded as a batch.
  assert next(batcher)
def test_GraphBatcher_max_node_count_limit_handler_error():
  """Test that error is raised when graph is larger than max node count."""
  oversized = random_graph_tuple_generator.CreateRandomGraphTuple(
      node_count=10)
  batcher = graph_batcher.GraphBatcher(
      MockIterator([oversized]),
      max_node_count=5,
      max_node_count_limit_handler="error",
  )
  # Requesting the first batch must fail because the graph cannot fit.
  with test.Raises(ValueError):
    next(batcher)
def test_fuzz_GraphBatcher(graph_count: int, max_graph_count: int,
                           max_node_count: int):
  """Fuzz the graph batcher with a range of parameter choices and input sizes.
  """
  inputs = MockIterator([
      random_graph_tuple_generator.CreateRandomGraphTuple()
      for _ in range(graph_count)
  ])
  batches = list(
      graph_batcher.GraphBatcher(inputs,
                                 max_node_count=max_node_count,
                                 max_graph_count=max_graph_count))
  # Regardless of limits, every input graph appears in exactly one batch.
  assert sum(batch.disjoint_graph_count for batch in batches) == graph_count
def test_GraphBatcher_exact_graph_count():
  """Test the number of batches when exact graph counts are required."""
  graphs = [
      random_graph_tuple_generator.CreateRandomGraphTuple() for _ in range(7)
  ]
  batches = list(
      graph_batcher.GraphBatcher(MockIterator(graphs), exact_graph_count=3))
  # 7 inputs with exact_graph_count=3 yield two full batches; the remaining
  # graph does not form a complete batch.
  assert len(batches) == 2
  assert batches[0].disjoint_graph_count == 3
  assert batches[1].disjoint_graph_count == 3
def Main():
  """Run the four-stage IR-to-graph-tuple pipeline and print summary stats.

  Stages: (1) build program graphs from LLVM IR, (2) encode node features,
  (3) produce reachability-labelled graphs, (4) pack them into graph tuples.
  """
  irs = [fs.Read(path) for path in LLVM_IR.iterdir()]
  ir_count = len(irs)

  with prof.ProfileToStdout(lambda t: (
      f"STAGE 1: Construct unlabelled graphs (llvm2graph) "
      f"({humanize.Duration(t / ir_count)} / IR)")):
    graphs = [llvm2graph.BuildProgramGraphNetworkX(ir) for ir in irs]

  encoder = node_encoder.GraphNodeEncoder()
  with prof.ProfileToStdout(lambda t: (
      f"STAGE 2: Encode graphs (inst2vec) "
      f"({humanize.Duration(t / ir_count)} / IR)")):
    for graph, ir in zip(graphs, irs):
      encoder.EncodeNodes(graph, ir)

  features_count = 0
  features_lists = []
  # NOTE: the profiler lambda reads features_count only when the `with` scope
  # exits, by which point it holds the final total.
  with prof.ProfileToStdout(lambda t: (
      f"STAGE 3: Produce labelled graphs (reachability analysis) "
      f"({humanize.Duration(t / features_count)} / graph)")):
    for graph in graphs:
      analysis = reachability.ReachabilityAnnotator(
          programl.NetworkXToProgramGraph(graph))
      features_list = analysis.MakeAnnotated(n=10).graphs
      features_count += len(features_list)
      features_lists.append(features_list)

  def IterGraphTuples():
    """Lazily convert every labelled graph to a graph tuple."""
    # Renamed from `iter()`, which shadowed the builtin of the same name.
    for features_list in features_lists:
      for graph in features_list:
        yield graph_tuple.GraphTuple.CreateFromNetworkX(graph)

  with prof.ProfileToStdout(lambda t: (
      f"STAGE 4: Construct graph tuples "
      f"({humanize.Duration(t / features_count)} / graph)")):
    batcher = graph_batcher.GraphBatcher(IterGraphTuples(),
                                         max_node_count=10000)
    graph_tuples = list(batcher)

  print("=================================")
  print(f"Unlabelled graphs count: {ir_count}")
  print(f" Labelled graphs count: {features_count}")
  print(f" Graph tuples count: {len(graph_tuples)}")
  print(
      f" Total node count: {sum(gt.node_count for gt in graph_tuples)}"
  )
  print(
      f" Total edge count: {sum(gt.edge_count for gt in graph_tuples)}"
  )
def test_GraphBatcher_max_node_count_limit_handler_skip():
  """Test that graph is skipped when larger than max node count."""
  big_graph = random_graph_tuple_generator.CreateRandomGraphTuple(
      node_count=10)
  batcher = graph_batcher.GraphBatcher(
      MockIterator([big_graph]),
      max_node_count=5,
      max_node_count_limit_handler="skip",
  )
  # The only graph is skipped, so the batcher must be exhausted. The previous
  # try/except-pass form also passed when the graph was (incorrectly) yielded,
  # i.e. it asserted nothing.
  with test.Raises(StopIteration):
    next(batcher)
def test_GraphBatcher_divisible_node_count():
  """Test the number of batches returned with evenly divisible node counts."""
  graphs = [
      random_graph_tuple_generator.CreateRandomGraphTuple(node_count=5)
      for _ in range(4)
  ]
  batches = list(
      graph_batcher.GraphBatcher(MockIterator(graphs), max_node_count=10))
  # Four 5-node graphs against a 10-node cap pack into two batches of two.
  assert len(batches) == 2
  for batch in batches:
    assert batch.is_disjoint_graph
    assert batch.disjoint_graph_count == 2
def test_GraphBatcher_empty_graphs_list():
  """An empty input iterator yields no batches."""
  empty_batcher = graph_batcher.GraphBatcher(MockIterator([]))
  with test.Raises(StopIteration):
    next(empty_batcher)