Example #1
def SelectRandomGraphs(graph_db: graph_tuple_database.Database):
    """Return [1, graph_db.graph_count] graphs in a random order."""
    with graph_db.Session() as session:
        # Load a random collection of graphs.
        graphs = (session.query(
            graph_tuple_database.GraphTuple).order_by(graph_db.Random()).limit(
                random.randint(1, graph_db.graph_count)).all())
        # Sanity check that graphs are returned.
        assert graphs

    return graphs
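
A minimal usage sketch (hedged): the database URL below is hypothetical, and the sketch assumes that graph_tuple_database.Database can be constructed from a URL string and that the database has already been populated with graphs.

# Hypothetical usage of the helper above. The URL is illustrative only.
db = graph_tuple_database.Database("sqlite:////tmp/graphs.db")
graphs = SelectRandomGraphs(db)
print(f"Read {len(graphs)} random graphs")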
Example #2

  def __init__(
    self,
    db: graph_tuple_database.Database,
    buffer_size_mb: int = 16,
    filters: Optional[List[Callable[[], bool]]] = None,
    order: BufferedGraphReaderOrder = BufferedGraphReaderOrder.IN_ORDER,
    eager_graph_loading: bool = True,
    limit: Optional[int] = None,
    ctx: progress.ProgressContext = progress.NullContext,
  ):
    """Constructor.

    Args:
      db: The database to iterate over.
      buffer_size_mb: The size of the in-memory graph buffer, in megabytes. A
        larger buffer reduces the number of queries, but increases the memory
        requirement.
      filters: An optional list of callbacks, where each callback returns a
        filter condition on the GraphTuple table.
      order: Determine the order to read graphs. See BufferedGraphReaderOrder.
      eager_graph_loading: If true, load the contents of the Graph table
        eagerly, preventing the need for subsequent SQL queries to access the
        graph data.
      limit: Limit the total number of rows returned to this value.
      ctx: A progress context used for logging and profiling.

    Raises:
      ValueError: If the database contains no graphs, or if the query with the
        given filters returns no results.
    """
    self.db = db
    self.order = order
    self.max_buffer_size = buffer_size_mb * 1024 * 1024
    self.eager_graph_loading = eager_graph_loading
    self.filters = filters or []
    self.ctx = ctx

    # Graphs that fail during dataset generation are inserted as zero-node
    # entries. Ignore those.
    self.filters.append(lambda: graph_tuple_database.GraphTuple.node_count > 1)

    if not self.db.graph_count:
      raise ValueError(f"Database contains no graphs: {self.db.url}")

    with ctx.Profile(
      3,
      lambda _: (
        f"Selected {humanize.Commas(self.n)} of "
        f"{humanize.Commas(self.db.graph_count)} graphs from database"
      ),
    ):
      with db.Session() as session:
        # Random ordering means that we can't use
        # labm8.py.sqlutil.OffsetLimitBatchedQuery() to read results as each
        # query will produce a different random order. Instead, first run a
        # query to read all of the IDs and the corresponding tuple sizes that
        # match the query, then iterate through the list of IDs.
        query = session.query(
          graph_tuple_database.GraphTuple.id,
          graph_tuple_database.GraphTuple.pickled_graph_tuple_size.label(
            "size"
          ),
        )

        # Apply the requested filters.
        for filter_cb in self.filters:
          query = query.filter(filter_cb())

        # When ordering by data flow steps or global random, the IDs are not
        # returned in sorted order, so later buffer reads must look graphs up
        # by ID. Otherwise the IDs are sorted, and buffer reads can scan the
        # graph table using index range checks.
        if order == BufferedGraphReaderOrder.DATA_FLOW_STEPS:
          self.ordered_ids = False
          query = query.order_by(
            graph_tuple_database.GraphTuple.data_flow_steps
          )
        elif order == BufferedGraphReaderOrder.GLOBAL_RANDOM:
          self.ordered_ids = False
          query = query.order_by(db.Random())
        else:
          self.ordered_ids = True
          query = query.order_by(graph_tuple_database.GraphTuple.id)

        # Read the full set of graph IDs and sizes.
        self.ids_and_sizes = [(row.id, row.size) for row in query.all()]

      if not self.ids_and_sizes:
        raise ValueError(
          f"Query on database `{db.url}` returned no results: "
          f"`{sqlutil.QueryToString(query)}`"
        )

      # When we are limiting the number of rows and not reading the table in
      # order, pick a random starting point in the list of IDs.
      if limit and order != BufferedGraphReaderOrder.IN_ORDER:
        batch_start = random.randint(
          0, max(len(self.ids_and_sizes) - limit - 1, 0)
        )
        self.ids_and_sizes = self.ids_and_sizes[
          batch_start : batch_start + limit
        ]
      elif limit:
        # If we are reading the table in order, we must still respect the limit
        # argument.
        self.ids_and_sizes = self.ids_and_sizes[:limit]

      self.i = 0
      self.n = len(self.ids_and_sizes)

      # The local buffer of graphs, and an index into that buffer.
      self.buffer: List[graph_tuple_database.GraphTuple] = []
      self.buffer_i = 0
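
For orientation, a hedged usage sketch of the constructor above. The class name BufferedGraphReader and its iteration behaviour are assumptions (only the constructor is shown here), and the filter and URL are illustrative.

# Hypothetical usage. `BufferedGraphReader` is the assumed name of the class
# that owns the constructor above, and it is assumed to be iterable.
db = graph_tuple_database.Database("sqlite:////tmp/graphs.db")
reader = BufferedGraphReader(
  db,
  buffer_size_mb=32,
  filters=[lambda: graph_tuple_database.GraphTuple.data_flow_steps <= 30],
  order=BufferedGraphReaderOrder.GLOBAL_RANDOM,
  limit=1000,
)
for graph_tuple in reader:
  ...  # Process each graph tuple.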