def MakeBatch(
  self,
  epoch_type: epoch.Type,
  graphs: Iterable[graph_tuple_database.GraphTuple],
  ctx: progress.ProgressContext = progress.NullContext,
) -> batches.Data:
  """Produce the next batch, with deliberately flaky behavior.

  The first batch is always empty; the second is always a real batch (so an
  epoch can never consist solely of empty batches); every batch after that is
  real with 50% probability, otherwise empty.
  """
  self._batch_count += 1
  count = self._batch_count
  # Real batch on the second call, or on any later call when the coin flip
  # succeeds. Short-circuit evaluation guarantees random.random() is only
  # consumed from the third call onwards, exactly as in the original
  # if/elif chain.
  if count == 2 or (count > 2 and random.random() < 0.5):
    return super(FlakyBatchModel, self).MakeBatch(epoch_type, graphs, ctx)
  return batches.EmptyBatch()
def MakeBatch(
  self,
  epoch_type: epoch.Type,
  graphs: Iterable[graph_tuple_database.GraphTuple],
  ctx: progress.ProgressContext = progress.NullContext,
) -> batches.Data:
  """Create a mini-batch of data from an iterator of graphs.

  Args:
    epoch_type: The type of epoch that the batch is for. Single-graph
      batches are discarded during TRAIN epochs when graph features are
      present (batch normalization needs > 1 item).
    graphs: An iterable of graphs to pack into a batch.
    ctx: A logging context.

  Returns:
    A single batch of data for feeding into RunBatch(). A batch consists of
    a list of graph IDs and a model-defined blob of data. If the list of
    graph IDs is empty, the batch is discarded and not fed into RunBatch().
    When the input is exhausted, returns batches.EndOfBatches().
  """
  # TODO(github.com/ChrisCummins/ProGraML/issues/24): The new graph batcher
  # implementation is not well suited for reading the graph IDs, hence this
  # somewhat clumsy iterator wrapper. A neater approach would be to create
  # a graph batcher which returns a list of graphs in the batch.
  class GraphIterator(object):
    """A wrapper around a graph iterator which records graphs read."""

    def __init__(self, graphs: Iterable[graph_tuple_database.GraphTuple]):
      # Bug fix: coerce the input to an iterator so that the next() call in
      # __next__ works for any iterable (e.g. a list), matching the declared
      # Iterable type. iter() is a no-op on objects that are already
      # iterators, so existing callers are unaffected.
      self.input_graphs = iter(graphs)
      # Every graph consumed so far, in order; used below to recover the
      # graph IDs of the batch.
      self.graphs_read: List[graph_tuple_database.GraphTuple] = []

    def __iter__(self):
      return self

    def __next__(self):
      graph: graph_tuple_database.GraphTuple = next(self.input_graphs)
      self.graphs_read.append(graph)
      return graph.tuple

  graph_iterator = GraphIterator(graphs)

  # Create a disjoint graph out of one or more input graphs.
  batcher = graph_batcher.GraphBatcher.CreateFromFlags(
    graph_iterator, ctx=ctx
  )
  try:
    disjoint_graph = next(batcher)
  except StopIteration:
    # We have run out of graphs.
    return batches.EndOfBatches()

  # Workaround for the fact that graph batcher may read one more graph than
  # actually gets included in the batch: drop the trailing look-ahead graph
  # from the recorded list when the batcher held one back.
  if batcher.last_graph:
    graphs = graph_iterator.graphs_read[:-1]
  else:
    graphs = graph_iterator.graphs_read

  # Discard single-graph batches during training when there are graph
  # features. This is because we use batch normalization on incoming
  # features, and batch normalization requires > 1 items to normalize.
  if (
    len(graphs) <= 1
    and epoch_type == epoch.Type.TRAIN
    and disjoint_graph.graph_x_dimensionality
  ):
    return batches.EmptyBatch()

  return batches.Data(
    graph_ids=[graph.id for graph in graphs],
    data=GgnnBatchData(disjoint_graph=disjoint_graph, graphs=graphs),
  )