def test_empty_batches_is_not_error(
  graph_db: graph_tuple_database.Database,
  logger: logging.Logger,
):
  """Test that empty batches are ignored.

  Regression test for <github.com/ChrisCummins/ProGraML/issues/43>. Empty
  batch generation was determined to be the cause of flaky model crashes.
  """

  class FlakyBatchModel(MockModel):
    """A mock model which returns ~50% empty batches."""

    def __init__(self, *args, **kwargs):
      super(FlakyBatchModel, self).__init__(*args, **kwargs)
      self._batch_count = 0

    def MakeBatch(
      self,
      epoch_type: epoch.Type,
      graphs: Iterable[graph_tuple_database.GraphTuple],
      ctx: progress.ProgressContext = progress.NullContext,
    ) -> batches.Data:
      self._batch_count += 1
      if self._batch_count == 1:
        # Always return an empty first batch.
        return batches.EmptyBatch()
      elif self._batch_count == 2:
        # Always return a real second batch (otherwise the epoch may end up
        # with nothing but empty batches).
        return super(FlakyBatchModel, self).MakeBatch(epoch_type, graphs, ctx)
      # Return subsequent batches with a 50% success rate.
      if random.random() < 0.5:
        return super(FlakyBatchModel, self).MakeBatch(epoch_type, graphs, ctx)
      else:
        return batches.EmptyBatch()

  run_id = run_id_lib.RunId.GenerateUnique(
    f"mock{random.randint(0, int(1e6)):06}"
  )
  model = FlakyBatchModel(logger=logger, graph_db=graph_db, run_id=run_id)

  batch_iterator = batches.BatchIterator(
    batches=model.BatchIterator(
      epoch.Type.TRAIN, graph_database_reader.BufferedGraphReader(graph_db)
    ),
    graph_count=graph_db.graph_count,
  )

  model.Initialize()
  results = model(
    epoch_type=epoch.Type.TRAIN, batch_iterator=batch_iterator, logger=logger
  )
  assert results.batch_count >= 1
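# For context, a minimal sketch of the defensive pattern that the regression
# test above exercises: empty batches are skipped rather than treated as
# errors. This is illustrative only, not the actual ClassifierBase epoch loop,
# and the `graph_ids` emptiness check is an assumption about the fields of the
# batches.Data tuple.
def _CountNonEmptyBatches(batch_iterator: batches.BatchIterator) -> int:
  """Count the batches in an iterator, silently ignoring empty ones."""
  batch_count = 0
  for batch in batch_iterator.batches:
    if not batch.graph_ids:
      # An empty batch contains no graphs; skip it instead of crashing.
      continue
    # A real epoch loop would run the batch through the model here.
    batch_count += 1
  return batch_count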
def MakeBatchIterator(
  model: classifier_base.ClassifierBase,
  graph_db: graph_tuple_database.Database,
  splits: Dict[epoch.Type, List[int]],
  epoch_type: epoch.Type,
  ctx: progress.ProgressContext = progress.NullContext,
) -> batches.BatchIterator:
  """Create an iterator over batches for the given epoch type.

  Args:
    model: The model to generate a batch iterator for.
    graph_db: The graph database to read graphs from.
    splits: A mapping from epoch type to a list of split numbers.
    epoch_type: The type of epoch to produce an iterator for.
    ctx: A logging context.

  Returns:
    A batch iterator for feeding into model.RunBatch().
  """
  # Filter the graph database to load graphs from the requested splits.
  if epoch_type == epoch.Type.TRAIN:
    limit = FLAGS.max_train_per_epoch
  elif epoch_type == epoch.Type.VAL:
    limit = FLAGS.max_val_per_epoch
  elif epoch_type == epoch.Type.TEST:
    limit = None  # Never limit the test set.
  else:
    raise NotImplementedError("unreachable")

  splits_for_type = splits[epoch_type]
  ctx.Log(
    3,
    "Using %s graph splits %s",
    epoch_type.name.lower(),
    sorted(splits_for_type),
  )

  if len(splits_for_type) == 1:
    # A single split can use a simple equality check.
    split_filter = (
      lambda: graph_tuple_database.GraphTuple.split == splits_for_type[0]
    )
  else:
    # Multiple splits require an IN check.
    split_filter = lambda: graph_tuple_database.GraphTuple.split.in_(
      splits_for_type
    )

  graph_reader = model.GraphReader(
    epoch_type=epoch_type,
    graph_db=graph_db,
    filters=[split_filter],
    limit=limit,
    ctx=ctx,
  )

  return batches.BatchIterator(
    batches=ppar.ThreadedIterator(
      model.BatchIterator(epoch_type, graph_reader, ctx=ctx),
      max_queue_size=FLAGS.batch_queue_size,
    ),
    graph_count=graph_reader.n,
  )
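# A minimal usage sketch for MakeBatchIterator. The split layout below is a
# hypothetical example (splits 0-7 train, 8 validation, 9 test); adjust it to
# match the split numbers actually assigned in your graph database.
def _MakeTrainBatchIteratorExample(
  model: classifier_base.ClassifierBase,
  graph_db: graph_tuple_database.Database,
) -> batches.BatchIterator:
  """Build a training batch iterator under an assumed 10-fold split layout."""
  splits = {
    epoch.Type.TRAIN: list(range(8)),
    epoch.Type.VAL: [8],
    epoch.Type.TEST: [9],
  }
  return MakeBatchIterator(
    model=model,
    graph_db=graph_db,
    splits=splits,
    epoch_type=epoch.Type.TRAIN,
  )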
def batch_iterator(
  model: MockModel,
  graph_db: graph_tuple_database.Database,
) -> batches.BatchIterator:
  """A test fixture which returns an iterator over training batches."""
  return batches.BatchIterator(
    batches=model.BatchIterator(
      epoch.Type.TRAIN, graph_database_reader.BufferedGraphReader(graph_db)
    ),
    graph_count=graph_db.graph_count,
  )
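# A minimal sketch of how the batch_iterator fixture might be consumed,
# assuming it is registered as a test fixture. The test name is hypothetical;
# the model invocation follows the regression test above.
def test_batch_iterator_produces_batches(
  model: MockModel,
  batch_iterator: batches.BatchIterator,
  logger: logging.Logger,
):
  """Run a training epoch over the fixture's batches."""
  model.Initialize()
  results = model(
    epoch_type=epoch.Type.TRAIN, batch_iterator=batch_iterator, logger=logger
  )
  assert results.batch_count >= 1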