# Example #1
# 0
def populated_log_db(
  request,
  generator: random_log_database_generator.RandomLogDatabaseGenerator,
) -> DatabaseAndRunIds:
  """A test fixture which yields a log database populated with 10 random runs.

  Args:
    request: The fixture request object. Its `param` is forwarded to
      testing_databases.DatabaseContext along with log_database.Database.
    generator: The random log generator used to populate the database.

  Yields:
    A DatabaseAndRunIds whose `db` is the database and whose `run_ids` is the
    value returned by generator.PopulateLogDatabase(db, run_count=10).
  """
  with testing_databases.DatabaseContext(
    log_database.Database, request.param
  ) as db:
    yield DatabaseAndRunIds(
      db=db, run_ids=generator.PopulateLogDatabase(db, run_count=10)
    )
# Example #2
# 0
def disposable_populated_log_db(
  request, generator: random_log_database_generator.RandomLogDatabaseGenerator
) -> DatabaseAndRunIds:
  """Yield a log database populated with 10 random runs, like populated_log_db.

  Intended to be generated fresh for each test. The actual fixture scope is
  set where this fixture is registered, which is not visible here.

  Args:
    request: The fixture request object. Its `param` is forwarded to
      testing_databases.DatabaseContext.
    generator: The random log generator used to populate the database.

  Yields:
    A DatabaseAndRunIds pairing the database with the run IDs returned by
    generator.PopulateLogDatabase().
  """
  context = testing_databases.DatabaseContext(
    log_database.Database, request.param
  )
  with context as database:
    populated_run_ids = generator.PopulateLogDatabase(database, run_count=10)
    yield DatabaseAndRunIds(db=database, run_ids=populated_run_ids)
# Example #3
# 0
def two_run_id_session(
  empty_db: log_database.Database,
  generator: random_log_database_generator.RandomLogDatabaseGenerator,
) -> "DatabaseSessionWithRunLogs":
  """A test fixture which yields a session holding random logs for two runs.

  Args:
    empty_db: The database on which the session is opened (expected, per its
      name, to start empty).
    generator: The random log generator used to create the two runs' logs.

  Yields:
    A DatabaseSessionWithRunLogs holding the open session and the generated
    run logs `a` and `b`. The logs are added to the session; this fixture does
    not call commit() itself.
  """
  # Use the `run_id_lib` module alias, as the tests in this file do: `run_id`
  # is also used as a fixture/parameter name and may shadow the module.
  a = generator.CreateRandomRunLogs(
    run_id=run_id_lib.RunId.GenerateUnique("a")
  )
  b = generator.CreateRandomRunLogs(
    run_id=run_id_lib.RunId.GenerateUnique("b")
  )

  with empty_db.Session() as session:
    session.add_all(a.all + b.all)
    yield DatabaseSessionWithRunLogs(session=session, a=a, b=b)
def test_parameters(
  generator: random_log_database_generator.RandomLogDatabaseGenerator,
  run_id: run_id_lib.RunId,
  max_param_count: int,
  db_session: log_database.Database.SessionType,
):
  """Black-box check that generated parameter logs are sane and storable.

  The generated run must have at least one parameter. Every parameter must
  be a log_database.Parameter tagged with the requested run ID, and the
  run's logs must be addable to and committable in `db_session`.
  """
  run_logs = generator.CreateRandomRunLogs(
    run_id=run_id, max_param_count=max_param_count
  )
  params = run_logs.parameters

  # max_param_count is not usable as an upper bound here: the log generator
  # can add additional graph_db parameters on top of it.
  assert len(params) >= 1
  for p in params:
    assert isinstance(p, log_database.Parameter) and p.run_id == run_id

  # Add and commit the logs to check that they can be persisted.
  db_session.add_all(run_logs.all)
  db_session.commit()
def test_batches(
  generator: random_log_database_generator.RandomLogDatabaseGenerator,
  run_id: run_id_lib.RunId,
  max_epoch_count: int,
  max_batch_count: int,
  db_session: log_database.Database.SessionType,
):
  """Black-box check that generated batch logs are in range and storable.

  The number of generated batches must lie in
  [2, 3 * max_epoch_count * max_batch_count]. Every batch must be a
  log_database.Batch tagged with the requested run ID, and the run's logs
  must be addable to and committable in `db_session`.
  """
  run_logs = generator.CreateRandomRunLogs(
    run_id=run_id,
    max_epoch_count=max_epoch_count,
    max_batch_count=max_batch_count,
  )
  batches = run_logs.batches

  batch_count = len(batches)
  upper_bound = 3 * max_epoch_count * max_batch_count
  assert 2 <= batch_count <= upper_bound
  for b in batches:
    assert isinstance(b, log_database.Batch) and b.run_id == run_id

  # Add and commit the logs to check that they can be persisted.
  db_session.add_all(run_logs.all)
  db_session.commit()
def test_PopulateLogDatabase(
  generator: random_log_database_generator.RandomLogDatabaseGenerator,
  db: log_database.Database,
  run_count: int,
  max_param_count: int,
  max_epoch_count: int,
  max_batch_count: int,
):
  """Populate `db` and check how many distinct runs each log table covers.

  Parameters and batches must each be recorded for exactly `run_count`
  distinct run IDs. Checkpoints must cover at most `run_count` run IDs.
  """
  generator.PopulateLogDatabase(
    db,
    run_count,
    max_param_count=max_param_count,
    max_epoch_count=max_epoch_count,
    max_batch_count=max_batch_count,
  )

  def _distinct_count(session, column) -> int:
    """Count the distinct values of `column`, queried through `session`."""
    query = session.query(sql.func.count(sql.func.distinct(column)))
    return query.scalar()

  with db.Session() as session:
    param_runs = _distinct_count(session, log_database.Parameter.run_id)
    assert param_runs == run_count

    batch_runs = _distinct_count(session, log_database.Batch.run_id)
    assert batch_runs == run_count

    # Only an upper bound is required for checkpoints.
    checkpoint_runs = _distinct_count(session, log_database.Checkpoint.run_id)
    assert checkpoint_runs <= run_count