コード例 #1
0
        def MakeBatch(
            self,
            epoch_type: epoch.Type,
            graphs: Iterable[graph_tuple_database.GraphTuple],
            ctx: progress.ProgressContext = progress.NullContext,
        ) -> batches.Data:
            """Produce a batch that is randomly replaced by an empty batch.

            Calls are counted via self._batch_count, which this method only ever
            increments:
              * call 1 always returns batches.EmptyBatch();
              * call 2 always returns the parent class's MakeBatch() result, so the
                epoch cannot end up with nothing but empty batches;
              * every later call returns the parent's result with probability 0.5
                (decided by random.random()), and batches.EmptyBatch() otherwise.
            """
            self._batch_count += 1
            call_number = self._batch_count

            if call_number == 1:
                # The very first batch is deliberately empty.
                return batches.EmptyBatch()

            # The short-circuit keeps call two deterministic: random.random() is
            # only consulted from the third call onward.
            use_real_batch = call_number == 2 or random.random() < 0.5
            if not use_real_batch:
                return batches.EmptyBatch()

            parent = super(FlakyBatchModel, self)
            return parent.MakeBatch(epoch_type, graphs, ctx)
コード例 #2
0
  def MakeBatch(
    self,
    epoch_type: epoch.Type,
    graphs: Iterable[graph_tuple_database.GraphTuple],
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Data:
    """Create a mini-batch of data from an iterator of graphs.

    Args:
      epoch_type: The type of epoch the batch is for. During TRAIN epochs,
        batches of at most one graph are discarded when the graphs carry
        graph-level features (see below).
      graphs: The graphs to draw from. Any iterable is accepted; it is passed
        through iter(). If an iterator is given, consumption persists across
        calls, so successive calls continue where the previous batch stopped.
        A non-iterator iterable (e.g. a list) is re-read from its start on
        every call.
      ctx: A progress context, forwarded to the graph batcher.

    Returns:
      A single batch of data for feeding into RunBatch(). A batch consists of a
      list of graph IDs and a model-defined blob of data. If the list of graph
      IDs is empty, the batch is discarded and not fed into RunBatch().
      Specifically:
        * batches.EndOfBatches() when the batcher yields no disjoint graph;
        * batches.EmptyBatch() for a TRAIN batch of <= 1 graph that has graph
          features;
        * otherwise a batches.Data holding the graph IDs and a GgnnBatchData.
    """
    # TODO(github.com/ChrisCummins/ProGraML/issues/24): The new graph batcher
    # implementation is not well suited for reading the graph IDs, hence this
    # somewhat clumsy iterator wrapper. A neater approach would be to create
    # a graph batcher which returns a list of graphs in the batch.
    class GraphIterator(object):
      """Wraps a graph iterator, recording every graph read and yielding its tuple."""

      def __init__(self, graphs: Iterable[graph_tuple_database.GraphTuple]):
        # iter() makes plain iterables (e.g. lists) usable with next() below;
        # for an object that is already an iterator it returns the same object.
        self.input_graphs = iter(graphs)
        self.graphs_read: List[graph_tuple_database.GraphTuple] = []

      def __iter__(self):
        return self

      def __next__(self):
        # StopIteration from the underlying iterator propagates unchanged,
        # signalling exhaustion to the batcher.
        graph: graph_tuple_database.GraphTuple = next(self.input_graphs)
        self.graphs_read.append(graph)
        return graph.tuple

    graph_iterator = GraphIterator(graphs)

    # Create a disjoint graph out of one or more input graphs.
    batcher = graph_batcher.GraphBatcher.CreateFromFlags(
      graph_iterator, ctx=ctx
    )

    try:
      disjoint_graph = next(batcher)
    except StopIteration:
      # We have run out of graphs.
      return batches.EndOfBatches()

    # Workaround for the fact that graph batcher may read one more graph than
    # actually gets included in the batch. When batcher.last_graph is set, the
    # final graph recorded by the wrapper is that extra graph, so drop it.
    if batcher.last_graph:
      batch_graphs = graph_iterator.graphs_read[:-1]
    else:
      batch_graphs = graph_iterator.graphs_read

    # Discard single-graph batches during training when there are graph
    # features. This is because we use batch normalization on incoming features,
    # and batch normalization requires > 1 items to normalize.
    if (
      len(batch_graphs) <= 1
      and epoch_type == epoch.Type.TRAIN
      and disjoint_graph.graph_x_dimensionality
    ):
      return batches.EmptyBatch()

    return batches.Data(
      graph_ids=[graph.id for graph in batch_graphs],
      data=GgnnBatchData(disjoint_graph=disjoint_graph, graphs=batch_graphs),
    )