def test_empty_batches_is_not_error(
    graph_db: graph_tuple_database.Database,
    logger: logging.Logger,
):
    """An epoch that yields some empty batches must still complete cleanly.

    Regression test for <github.com/ChrisCummins/ProGraML/issues/43>: empty
    batch generation was identified as the cause of flaky model crashes.
    """

    class FlakyBatchModel(MockModel):
        """A mock model whose MakeBatch() returns ~50% empty batches."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._batch_count = 0

        def MakeBatch(
            self,
            epoch_type: epoch.Type,
            graphs: Iterable[graph_tuple_database.GraphTuple],
            ctx: progress.ProgressContext = progress.NullContext,
        ) -> batches.Data:
            self._batch_count += 1
            # The first batch is always empty and the second is always real,
            # so the epoch is guaranteed at least one non-empty batch.
            if self._batch_count == 1:
                return batches.EmptyBatch()
            if self._batch_count == 2:
                return super().MakeBatch(epoch_type, graphs, ctx)
            # Every subsequent batch is empty with ~50% probability.
            if random.random() < 0.5:
                return super().MakeBatch(epoch_type, graphs, ctx)
            return batches.EmptyBatch()

    run_id = run_id_lib.RunId.GenerateUnique(
        f"mock{random.randint(0, int(1e6)):06}")

    model = FlakyBatchModel(
        logger=logger,
        graph_db=graph_db,
        run_id=run_id,
    )

    batch_iterator = batches.BatchIterator(
        batches=model.BatchIterator(
            epoch.Type.TRAIN,
            graph_database_reader.BufferedGraphReader(graph_db)),
        graph_count=graph_db.graph_count,
    )

    model.Initialize()
    results = model(
        epoch_type=epoch.Type.TRAIN,
        batch_iterator=batch_iterator,
        logger=logger,
    )
    # The epoch must have processed at least one (non-empty) batch.
    assert results.batch_count >= 1
# ---- Example 2 ----
def MakeBatchIterator(
    model: classifier_base.ClassifierBase,
    graph_db: graph_tuple_database.Database,
    splits: Dict[epoch.Type, List[int]],
    epoch_type: epoch.Type,
    ctx: progress.ProgressContext = progress.NullContext,
) -> batches.BatchIterator:
    """Create an iterator over batches for the given epoch type.

  Args:
    model: The model to generate a batch iterator for.
    graph_db: The database of graphs to read batches from.
    splits: A mapping from epoch type to a list of split numbers.
    epoch_type: The type of epoch to produce an iterator for.
    ctx: A logging context.

  Returns:
    A batch iterator for feeding into model.RunBatch().
  """
    # Cap the number of graphs per epoch for train/val runs. The test set is
    # never limited.
    if epoch_type == epoch.Type.TRAIN:
        limit = FLAGS.max_train_per_epoch
    elif epoch_type == epoch.Type.VAL:
        limit = FLAGS.max_val_per_epoch
    elif epoch_type == epoch.Type.TEST:
        limit = None
    else:
        raise NotImplementedError("unreachable")

    chosen_splits = splits[epoch_type]
    ctx.Log(
        3,
        "Using %s graph splits %s",
        epoch_type.name.lower(),
        sorted(chosen_splits),
    )

    # Build a filter callback restricting the reader to the chosen splits. A
    # single split is expressed as an equality test rather than an IN clause.
    if len(chosen_splits) == 1:
        only_split = chosen_splits[0]
        split_filter = (
            lambda: graph_tuple_database.GraphTuple.split == only_split)
    else:
        split_filter = (
            lambda: graph_tuple_database.GraphTuple.split.in_(chosen_splits))

    graph_reader = model.GraphReader(
        epoch_type=epoch_type,
        graph_db=graph_db,
        filters=[split_filter],
        limit=limit,
        ctx=ctx,
    )

    # Decouple batch construction from batch consumption with a background
    # thread feeding a bounded queue.
    threaded_batches = ppar.ThreadedIterator(
        model.BatchIterator(epoch_type, graph_reader, ctx=ctx),
        max_queue_size=FLAGS.batch_queue_size,
    )
    return batches.BatchIterator(
        batches=threaded_batches,
        graph_count=graph_reader.n,
    )
# ---- Example 3 ----
def batch_iterator(
    model: MockModel,
    graph_db: graph_tuple_database.Database,
) -> batches.BatchIterator:
    """Build a training batch iterator over every graph in the database."""
    graph_reader = graph_database_reader.BufferedGraphReader(graph_db)
    raw_batches = model.BatchIterator(epoch.Type.TRAIN, graph_reader)
    return batches.BatchIterator(
        batches=raw_batches,
        graph_count=graph_db.graph_count,
    )