Example #1
def test_Run_test_only(
  disposable_log_db: log_database.Database,
  graph_db: graph_tuple_database.Database,
  k_fold: bool,
):
  """Test the run.Run() method."""
  log_db = disposable_log_db

  # Set the flags that determine the behaviour of Run().
  FLAGS.graph_db = flags_parsers.DatabaseFlag(
    graph_tuple_database.Database, graph_db.url, must_exist=True
  )
  FLAGS.log_db = flags_parsers.DatabaseFlag(
    log_database.Database, log_db.url, must_exist=True
  )
  FLAGS.test_only = True
  FLAGS.k_fold = k_fold

  run.Run(MockModel)

  # Check that k-fold produces one run per split, and a single run otherwise.
  assert log_db.run_count == (graph_db.split_count if k_fold else 1)

  run_ids = log_db.run_ids
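  # Analyze the logs of each run.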
  for run_id in run_ids:
    logs = log_analysis.RunLogAnalyzer(log_db=log_db, run_id=run_id)
    epochs = logs.tables["epochs"]

    # Check that we performed as many epochs as expected.
    assert 1 == len(epochs)
    test_count = len(epochs[epochs["test_accuracy"].notnull()])
    # Check that we produced a test result.
    assert test_count == 1
Example #2
def test_RunLogAnalyser_best_epoch_num(populated_log_db: DatabaseAndRunIds,
                                       metric: str):
    """Black-box test that run log properties work."""
    for run_id in populated_log_db.run_ids:
        run = log_analysis.RunLogAnalyzer(populated_log_db.db, run_id)
        try:
            assert run.GetBestEpochNum(metric=metric)
        except ValueError as e:
            # Some metrics raise an error if no epoch ever reached them. This is fine.
            assert str(e) == f"No {run_id} epochs reached {metric}"
Example #3
def test_Run(
  disposable_log_db: log_database.Database,
  graph_db: graph_tuple_database.Database,
  k_fold: bool,
  run_with_memory_profiler: bool,
  test_on: str,
  stop_at: List[str],
):
  """Test the run.Run() method."""
  log_db = disposable_log_db

  # Set the flags that determine the behaviour of Run().
  FLAGS.graph_db = flags_parsers.DatabaseFlag(
    graph_tuple_database.Database, graph_db.url, must_exist=True
  )
  FLAGS.log_db = flags_parsers.DatabaseFlag(
    log_database.Database, log_db.url, must_exist=True
  )
  FLAGS.epoch_count = 3
  FLAGS.k_fold = k_fold
  FLAGS.run_with_memory_profiler = run_with_memory_profiler
  FLAGS.test_on = test_on
  FLAGS.stop_at = stop_at

  run.Run(MockModel)

  # Check that k-fold produces one run per split, and a single run otherwise.
  assert log_db.run_count == (graph_db.split_count if k_fold else 1)

  run_ids = log_db.run_ids
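  # Analyze the logs of each run.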
  for run_id in run_ids:
    logs = log_analysis.RunLogAnalyzer(log_db=log_db, run_id=run_id)
    epochs = logs.tables["epochs"]

    # Check that we performed as many epochs as expected. We can't check the
    # exact value because of --stop_at options.
    assert 1 <= len(epochs) <= FLAGS.epoch_count

    test_count = len(epochs[epochs["test_accuracy"].notnull()])

    # Check that the number of test epochs matches what the --test_on flag
    # requests.
    if test_on == "none":
      assert test_count == 0
    elif test_on == "best":
      assert test_count == 1
    elif test_on == "improvement":
      assert test_count >= 1
    elif test_on == "improvement_and_last":
      assert test_count >= 1
Example #4
  def __init__(
    self,
    log_db: log_database.Database,
    checkpoint_ref: checkpoints.CheckpointReference,
    epoch_type: epoch.Type,
    outdir: pathlib.Path,
    export_batches_per_query: int = 128,
  ):
    """Constructor.

    Args:
      log_db: The database to export logs from.
      checkpoint_ref: A run ID and epoch number.
      epoch_type: The type of epoch to export graphs from.
      outdir: The directory to write results to.
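      export_batches_per_query: The number of batches to read from the log
        database per query when exporting.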
    """
    self.log_db = log_db
    self.logger = logger_lib.Logger(self.log_db)
    self.checkpoint = checkpoint_ref
    self.epoch_type = epoch_type
    self.analyzer = log_analysis.RunLogAnalyzer(
      self.log_db, self.checkpoint.run_id
    )
    self.export_batches_per_query = export_batches_per_query

    self.outdir = outdir
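    # Create the output directory tree: subdirectories for graph statistics,
    # input graphs, and output graphs.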
    self.outdir.mkdir(parents=True, exist_ok=True)
    (self.outdir / "graph_stats").mkdir(exist_ok=True)
    (self.outdir / "in_graphs").mkdir(exist_ok=True)
    (self.outdir / "out_graphs").mkdir(exist_ok=True)

    # Count the total number of graphs to export across all batches.
    with self.log_db.Session() as session:
      num_graphs = self.FilterBatchesQuery(
        session.query(sql.func.sum(log_database.Batch.graph_count))
      ).scalar()

    if not num_graphs:
      raise ValueError("No graphs found!")

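    # Initialize the parent class: progress is measured in graphs, starting at
    # zero out of num_graphs graphs to export.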
    super(BatchDetailsExporter, self).__init__(
      name=f"export {self.checkpoint} graphs", unit="graphs", i=0, n=num_graphs,
    )
Example #5
def test_GetInputOutputGraphs(populated_log_db: DatabaseAndRunIds):
    """Test reconstructing graphs from a detailed batch."""
    # Select a random run to analyze.
    run_id = random.choice(populated_log_db.run_ids)
    run = log_analysis.RunLogAnalyzer(populated_log_db.db, run_id)

    with populated_log_db.db.Session() as session:
        # Select some random detailed batches to reconstruct the graphs of.
        detailed_batches = (
            session.query(log_database.Batch)
            .join(log_database.BatchDetails)
            .options(sql.orm.joinedload(log_database.Batch.details))
            .order_by(populated_log_db.db.Random())
            .limit(50)
            .all()
        )
        # Sanity check that there are detailed batches.
        assert detailed_batches

    for batch in detailed_batches:
        input_output_graphs = list(run.GetInputOutputGraphs(batch))
        # Check that the number of input_output_graphs matches the size of the
        # batch.
        assert len(input_output_graphs) == len(batch.graph_ids)
Example #6
def test_RunLogAnalyser_smoke_tests(populated_log_db: DatabaseAndRunIds):
    """Black-box test that run log properties work."""
    for run_id in populated_log_db.run_ids:
        run = log_analysis.RunLogAnalyzer(populated_log_db.db, run_id)
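        # The analyzer should expose its graph database and the standard set of
        # log tables.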
        assert run.graph_db
        assert run.tables.keys() == {"parameters", "epochs", "runs", "tags"}
Example #7
def test_RunLogAnalyser_empty_db(empty_log_db: log_database.Database):
    """Test that cannot analyse non-existing run."""
    with test.Raises(ValueError):
        log_analysis.RunLogAnalyzer(empty_log_db,
                                    run_id.RunId.GenerateUnique("foo"))