def PopulateWithTestSet(
    db: graph_tuple_database.Database,
    graph_count: int,
    node_x_dimensionality: int = 2,
    node_y_dimensionality: int = 0,
    graph_x_dimensionality: int = 0,
    graph_y_dimensionality: int = 0,
    with_data_flow: bool = False,
    split_count: int = 0,
):
    """Populate a database with "real" programs.

    Graph tuples are taken from
    `random_graph_tuple_generator.EnumerateTestSet(n=graph_count)`, repeated
    cyclically if necessary, and truncated to at most `graph_count` tuples.
    Each tuple is given randomly generated features/labels, converted to a
    `graph_tuple_database.GraphTuple` row, and all rows are added to `db` in a
    single committed session.

    Args:
        db: The database to populate.
        graph_count: The number of rows to create.
        node_x_dimensionality: Number of random integer node features, each
            0 or 1. If 0, node_x is not set.
        node_y_dimensionality: Number of random float node labels, each in
            [0, 1). If 0, node_y is not set.
        graph_x_dimensionality: Number of random integer graph features, each
            in [0, 50]. If 0, graph_x is not set.
        graph_y_dimensionality: Number of random float graph labels, each in
            [0, 1). If 0, graph_y is not set.
        with_data_flow: If True, set random data flow attributes
            (steps, root node, positive node count) on each row.
        split_count: If non-zero, assign each row a random split in
            [0, split_count). If 0, rows get no split.

    Returns:
        A `DatabaseAndRows` built from `db` and the list of rows added.
    """
    rows = []
    # Cycle over the test set in case it yields fewer than `graph_count`
    # tuples, and cap the total at `graph_count`.
    graph_tuples = itertools.islice(
        itertools.cycle(
            random_graph_tuple_generator.EnumerateTestSet(n=graph_count)),
        graph_count,
    )
    for i, graph_tuple in enumerate(graph_tuples):
        # Set the graph features and labels.
        node_x = (
            np.random.randint(
                low=0,
                high=2,
                size=(graph_tuple.node_count, node_x_dimensionality),
            )
            if node_x_dimensionality else None
        )
        node_y = (
            np.random.rand(graph_tuple.node_count, node_y_dimensionality)
            if node_y_dimensionality else None
        )
        graph_x = (
            np.random.randint(low=0, high=51, size=graph_x_dimensionality)
            if graph_x_dimensionality else None
        )
        # NOTE(review): graph_x is 1-D (graph_x_dimensionality,) while graph_y
        # is 2-D (graph_count, graph_y_dimensionality); confirm which shape
        # SetFeaturesAndLabels expects for each.
        graph_y = (
            np.random.rand(graph_tuple.graph_count, graph_y_dimensionality)
            if graph_y_dimensionality else None
        )
        graph_tuple = graph_tuple.SetFeaturesAndLabels(
            node_x=node_x,
            node_y=node_y,
            graph_x=graph_x,
            graph_y=graph_y,
            copy=False,
        )

        mapped = graph_tuple_database.GraphTuple.CreateFromGraphTuple(
            graph_tuple,
            ir_id=i + 1,
            # randrange() excludes its upper bound, so exactly `split_count`
            # distinct splits (0 .. split_count - 1) are produced. The previous
            # randint(0, split_count) was inclusive and produced one extra.
            split=random.randrange(split_count) if split_count else None,
        )

        if with_data_flow:
            mapped.data_flow_steps = random.randint(1, 50)
            mapped.data_flow_root_node = random.randint(
                0, mapped.node_count - 1)
            # The node_count - 1 upper bound presumably leaves at least one
            # negative node; clamp it to 1 so single-node graphs don't raise
            # ValueError from an empty randint range.
            mapped.data_flow_positive_node_count = random.randint(
                1, max(1, mapped.node_count - 1))

        rows.append(mapped)

    with db.Session(commit=True) as session:
        session.add_all(rows)

    return DatabaseAndRows(db, rows)
def test_EnumerateTestSet():
    """Test the "real" protos.

    Enumerating the test set with default arguments must produce exactly
    100 items.
    """
    enumerated_count = sum(
        1 for _ in random_graph_tuple_generator.EnumerateTestSet())
    assert enumerated_count == 100