def SelectRandomGraphs(graph_db: graph_tuple_database.Database):
  """Return [1, graph_db.graph_count] graphs in a random order.

  Args:
    graph_db: The database to select graphs from.

  Returns:
    A list of between 1 and graph_db.graph_count GraphTuple rows, in a
    random order.

  Raises:
    ValueError: If the database contains no graphs, or if the query
      unexpectedly returns no results.
  """
  # Guard against an empty database up front: random.randint(1, 0) would
  # otherwise raise an opaque ValueError from inside the query expression.
  if not graph_db.graph_count:
    raise ValueError(f"Database contains no graphs: {graph_db.url}")
  with graph_db.Session() as session:
    # Load a random collection of graphs.
    graphs = (
      session.query(graph_tuple_database.GraphTuple)
      .order_by(graph_db.Random())
      .limit(random.randint(1, graph_db.graph_count))
      .all()
    )
  # Sanity check that graphs are returned. Raise explicitly rather than
  # using `assert`, which is silently stripped under `python -O`.
  if not graphs:
    raise ValueError(f"Query returned no graphs from database: {graph_db.url}")
  return graphs
def __init__(
  self,
  db: graph_tuple_database.Database,
  buffer_size_mb: int = 16,
  filters: Optional[List[Callable[[], bool]]] = None,
  order: BufferedGraphReaderOrder = BufferedGraphReaderOrder.IN_ORDER,
  eager_graph_loading: bool = True,
  limit: Optional[int] = None,
  ctx: progress.ProgressContext = progress.NullContext,
):
  """Constructor.

  Args:
    db: The database to iterate over.
    buffer_size_mb: The size of the in-memory graph buffer, in megabytes.
      A larger buffer reduces the number of SQL queries, but increases the
      memory requirement.
    filters: An optional list of callbacks, where each callback returns a
      filter condition on the GraphTuple table. The list is copied; the
      caller's list is never modified.
    order: Determine the order to read graphs. See BufferedGraphReaderOrder.
    eager_graph_loading: If true, load the contents of the Graph table
      eagerly, preventing the need for subsequent SQL queries to access the
      graph data.
    limit: Limit the total number of rows returned to this value.
    ctx: A progress context for logging and profiling.

  Raises:
    ValueError: If the database contains no graphs, or if the query with
      the given filters returns no results.
  """
  self.db = db
  self.order = order
  self.max_buffer_size = buffer_size_mb * 1024 * 1024
  self.eager_graph_loading = eager_graph_loading
  # Copy the filters list so that the append below cannot mutate the
  # caller's argument. (`filters or []` would alias the caller's list.)
  self.filters = list(filters) if filters else []
  self.ctx = ctx

  # Graphs that fail during dataset generation are inserted as zero-node
  # entries. Ignore those.
  self.filters.append(lambda: graph_tuple_database.GraphTuple.node_count > 1)

  if not self.db.graph_count:
    raise ValueError(f"Database contains no graphs: {self.db.url}")

  with ctx.Profile(
    3,
    lambda _: (
      f"Selected {humanize.Commas(self.n)} of "
      f"{humanize.Commas(self.db.graph_count)} graphs from database"
    ),
  ):
    with db.Session() as session:
      # Random ordering means that we can't use
      # labm8.py.sqlutil.OffsetLimitBatchedQuery() to read results as each
      # query will produce a different random order. Instead, first run a
      # query to read all of the IDs and the corresponding tuple sizes that
      # match the query, then iterate through the list of IDs.
      query = session.query(
        graph_tuple_database.GraphTuple.id,
        graph_tuple_database.GraphTuple.pickled_graph_tuple_size.label(
          "size"
        ),
      )

      # Apply the requested filters.
      for filter_cb in self.filters:
        query = query.filter(filter_cb())

      # If we are ordering with global random then we can scan through the
      # graph table using index range checks, so we need the IDs sorted.
      if order == BufferedGraphReaderOrder.DATA_FLOW_STEPS:
        self.ordered_ids = False
        query = query.order_by(
          graph_tuple_database.GraphTuple.data_flow_steps
        )
      elif order == BufferedGraphReaderOrder.GLOBAL_RANDOM:
        self.ordered_ids = False
        query = query.order_by(db.Random())
      else:
        self.ordered_ids = True
        query = query.order_by(graph_tuple_database.GraphTuple.id)

      # Read the full set of graph IDs and sizes.
      self.ids_and_sizes = [(row.id, row.size) for row in query.all()]

    if not self.ids_and_sizes:
      raise ValueError(
        f"Query on database `{db.url}` returned no results: "
        f"`{sqlutil.QueryToString(query)}`"
      )

    # When we are limiting the number of rows and not reading the table in
    # order, pick a random starting point in the list of IDs.
    if limit and order != BufferedGraphReaderOrder.IN_ORDER:
      batch_start = random.randint(
        0, max(len(self.ids_and_sizes) - limit - 1, 0)
      )
      self.ids_and_sizes = self.ids_and_sizes[
        batch_start : batch_start + limit
      ]
    elif limit:
      # If we are reading the table in order, we must still respect the
      # limit argument.
      self.ids_and_sizes = self.ids_and_sizes[:limit]

  # Cursor into ids_and_sizes, and the total number of graphs selected.
  self.i = 0
  self.n = len(self.ids_and_sizes)

  # The local buffer of graphs, and an index into that buffer.
  self.buffer: List[graph_tuple_database.GraphTuple] = []
  self.buffer_i = 0