コード例 #1
0
ファイル: zero_r_test.py プロジェクト: monperrus/ProGraML
def test_graph_classifier_call(
  epoch_type: epoch.Type,
  logger: logging.Logger,
  graph_classification_graph_db: graph_tuple_database.Database,
):
  """Smoke-test calling a ZeroR model on the graph classification database.

  Builds a freshly initialized model, hands it a batch iterator for the
  requested epoch type, and checks that the call yields an `epoch.Results`
  with a truthy batch count.

  Args:
    epoch_type: The epoch type passed to the batch iterator and the model.
    logger: Logger given to the model constructor and to the model call.
    graph_classification_graph_db: Database used to construct the model and
      as the source of graphs for the batch iterator.
  """
  # Each invocation gets its own run ID, seeded with a random number that is
  # zero-padded to at least six digits.
  run_name = f"mock{random.randint(0, int(1e6)):06}"
  unique_run_id = run_id_lib.RunId.GenerateUnique(run_name)

  # Build the model; nothing trains it before it is called below.
  classifier = zero_r.ZeroR(
    logger, graph_classification_graph_db, run_id=unique_run_id
  )
  classifier.Initialize()

  # One single-element list per epoch type.
  # NOTE(review): presumably these are split IDs in the database - confirm.
  split_map = {
    epoch.Type.TRAIN: [0],
    epoch.Type.VAL: [1],
    epoch.Type.TEST: [2],
  }
  batches = batch_iterator_lib.MakeBatchIterator(
    model=classifier,
    graph_db=graph_classification_graph_db,
    splits=split_map,
    epoch_type=epoch_type,
  )

  outcome = classifier(
    epoch_type=epoch_type,
    batch_iterator=batches,
    logger=logger,
  )

  assert isinstance(outcome, epoch.Results)
  # Truthiness check: fails when the reported batch count is zero.
  assert outcome.batch_count
コード例 #2
0
def test_node_classifier_call(
    epoch_type: epoch.Type,
    node_classification_graph_db: graph_tuple_database.Database,
    logger: logging.Logger,
):
    """Smoke-test calling a ZeroR model on the node classification database.

    Builds a freshly initialized model, hands it a batch iterator for the
    requested epoch type, and checks that the call yields an `epoch.Results`
    with a truthy batch count.

    Args:
        epoch_type: The epoch type passed to the batch iterator and the model.
        node_classification_graph_db: Database used to construct the model and
            as the source of graphs for the batch iterator.
        logger: Logger given to the model constructor and to the model call.
    """
    # Build the model; nothing trains it before it is called below.
    classifier = zero_r.ZeroR(logger, node_classification_graph_db)
    classifier.Initialize()

    # One single-element list per epoch type.
    # NOTE(review): presumably these are split IDs in the database - confirm.
    split_map = {
        epoch.Type.TRAIN: [0],
        epoch.Type.VAL: [1],
        epoch.Type.TEST: [2],
    }
    batches = batch_iterator_lib.MakeBatchIterator(
        model=classifier,
        graph_db=node_classification_graph_db,
        splits=split_map,
        epoch_type=epoch_type,
    )

    outcome = classifier(
        epoch_type=epoch_type, batch_iterator=batches, logger=logger
    )

    assert isinstance(outcome, epoch.Results)
    # Truthiness check: fails when the reported batch count is zero.
    assert outcome.batch_count
コード例 #3
0
def test_load_restore_model_from_checkpoint_smoke_test(
  logger: logging.Logger,
  node_classification_graph_db: graph_tuple_database.Database,
):
  """Smoke-test a checkpoint save followed by a restore on a ZeroR model.

  The test passes as long as neither call raises; no state is compared
  before and after the round trip.

  Args:
    logger: Logger given to the model constructor.
    node_classification_graph_db: Database used to construct the model.
  """
  # Build the model; nothing trains it before the checkpoint is taken.
  classifier = zero_r.ZeroR(logger, node_classification_graph_db)
  classifier.Initialize()

  # Save, then immediately restore from the reference that was returned.
  saved_ref = classifier.SaveCheckpoint()
  classifier.RestoreFrom(saved_ref)
コード例 #4
0
ファイル: zero_r_test.py プロジェクト: monperrus/ProGraML
def test_load_restore_model_from_checkpoint_smoke_test(
  logger: logging.Logger,
  node_classification_graph_db: graph_tuple_database.Database,
):
  """Smoke-test a checkpoint save followed by a restore on a ZeroR model.

  The test passes as long as neither call raises; no state is compared
  before and after the round trip.

  Args:
    logger: Logger given to the model constructor.
    node_classification_graph_db: Database used to construct the model.
  """
  # Each invocation gets its own run ID, seeded with a random number that is
  # zero-padded to at least six digits.
  run_name = f"mock{random.randint(0, int(1e6)):06}"
  unique_run_id = run_id_lib.RunId.GenerateUnique(run_name)

  # Build the model; nothing trains it before the checkpoint is taken.
  classifier = zero_r.ZeroR(
    logger, node_classification_graph_db, run_id=unique_run_id
  )
  classifier.Initialize()

  # Save, then immediately restore from the reference that was returned.
  saved_ref = classifier.SaveCheckpoint()
  classifier.RestoreFrom(saved_ref)
コード例 #5
0
ファイル: zero_r_test.py プロジェクト: fehrethe/ProGraML
def test_load_restore_model_from_checkpoint_smoke_test(
  logger: logging.Logger,
  node_classification_graph_db: graph_tuple_database.Database,
):
  """Test creating and restoring a ZeroR model from a checkpoint.

  The test passes as long as saving and restoring do not raise; no state is
  compared before and after the round trip.

  Args:
    logger: Logger given to the model constructor.
    node_classification_graph_db: Database used to construct the model.
  """
  # Each invocation gets its own run ID, seeded with a random number that is
  # zero-padded to at least six digits.
  run_id = run_id_lib.RunId.GenerateUnique(
    f"mock{random.randint(0, int(1e6)):06}"
  )

  # Create and initialize an untrained model.
  model = zero_r.ZeroR(logger, node_classification_graph_db, run_id=run_id)
  model.Initialize()

  # Smoke test save and restore.
  checkpoint_ref = model.SaveCheckpoint()
  model.RestoreFrom(checkpoint_ref)


def test_node_classifier_call(
<<<<<<< HEAD:deeplearning/ml4pl/models/zero_r/zero_r_test.py
  epoch_type: epoch.Type,
  node_classification_graph_db: graph_tuple_database.Database,
  logger: logging.Logger,