def test_graph_classifier_call(
    epoch_type: epoch.Type,
    logger: logging.Logger,
    graph_classification_graph_db: graph_tuple_database.Database,
):
    """Test running a graph classifier.

    An untrained ZeroR model is built on the graph classification database
    under a freshly generated run ID. It is then called for one epoch of
    ``epoch_type``. The call must return an ``epoch.Results`` whose
    ``batch_count`` is truthy.
    """
    # Random integer in [0, 1e6], zero-padded to at least six digits in the
    # run ID prefix.
    suffix = random.randint(0, int(1e6))
    unique_run_id = run_id_lib.RunId.GenerateUnique(f"mock{suffix:06}")

    # Create and initialize an untrained model.
    classifier = zero_r.ZeroR(
        logger, graph_classification_graph_db, run_id=unique_run_id
    )
    classifier.Initialize()

    # One split value per epoch type (presumably split indices within
    # graph_db -- confirm against MakeBatchIterator).
    splits = {
        epoch.Type.TRAIN: [0],
        epoch.Type.VAL: [1],
        epoch.Type.TEST: [2],
    }
    batches = batch_iterator_lib.MakeBatchIterator(
        model=classifier,
        graph_db=graph_classification_graph_db,
        splits=splits,
        epoch_type=epoch_type,
    )

    # Run the model over the batches for the requested epoch type.
    results = classifier(
        epoch_type=epoch_type, batch_iterator=batches, logger=logger,
    )

    assert isinstance(results, epoch.Results)
    assert results.batch_count
def test_node_classifier_call(
    epoch_type: epoch.Type,
    node_classification_graph_db: graph_tuple_database.Database,
    logger: logging.Logger,
):
    """Test running a node classifier.

    An untrained ZeroR model is built on the node classification database and
    called for one epoch of ``epoch_type``. The call must return an
    ``epoch.Results`` whose ``batch_count`` is truthy.
    """
    # Generate a unique run ID exactly as test_graph_classifier_call does,
    # rather than relying on ZeroR's default run ID, so the graph and node
    # classifier tests construct their models the same way.
    run_id = run_id_lib.RunId.GenerateUnique(
        f"mock{random.randint(0, int(1e6)):06}"
    )

    # Create and initialize an untrained model.
    model = zero_r.ZeroR(logger, node_classification_graph_db, run_id=run_id)
    model.Initialize()

    # Run the model over some random graphs. Each epoch type maps to one
    # split value (presumably a split index within graph_db -- confirm).
    batch_iterator = batch_iterator_lib.MakeBatchIterator(
        model=model,
        graph_db=node_classification_graph_db,
        splits={
            epoch.Type.TRAIN: [0],
            epoch.Type.VAL: [1],
            epoch.Type.TEST: [2],
        },
        epoch_type=epoch_type,
    )

    results = model(
        epoch_type=epoch_type, batch_iterator=batch_iterator, logger=logger,
    )

    assert isinstance(results, epoch.Results)
    assert results.batch_count
# NOTE: Renamed from test_load_restore_model_from_checkpoint_smoke_test. The
# module defines that name again further down (the variant that passes an
# explicit run_id). That rebinding shadowed this function, so pytest never
# collected it. A distinct name lets the default-run-ID path be tested too.
def test_load_restore_model_from_checkpoint_with_default_run_id_smoke_test(
    logger: logging.Logger,
    node_classification_graph_db: graph_tuple_database.Database,
):
    """Test creating and restoring model from checkpoint.

    The ZeroR model is constructed without an explicit ``run_id``, so whatever
    default ZeroR applies is used. The test only checks that a save/restore
    round trip completes without raising.
    """
    # Create and initialize an untrained model.
    model = zero_r.ZeroR(logger, node_classification_graph_db)
    model.Initialize()

    # Smoke test save and restore.
    checkpoint_ref = model.SaveCheckpoint()
    model.RestoreFrom(checkpoint_ref)
def test_load_restore_model_from_checkpoint_smoke_test(
    logger: logging.Logger,
    node_classification_graph_db: graph_tuple_database.Database,
):
    """Test creating and restoring model from checkpoint.

    A freshly initialized ZeroR model is created under a uniquely generated
    run ID. A checkpoint is saved and immediately restored. There are no
    assertions: the test passes as long as neither step raises.
    """
    # Random integer in [0, 1e6], zero-padded to at least six digits.
    mock_suffix = random.randint(0, int(1e6))

    # Create and initialize an untrained model.
    model = zero_r.ZeroR(
        logger,
        node_classification_graph_db,
        run_id=run_id_lib.RunId.GenerateUnique(f"mock{mock_suffix:06}"),
    )
    model.Initialize()

    # Smoke test: restore directly from the checkpoint reference just saved.
    model.RestoreFrom(model.SaveCheckpoint())
def test_load_restore_model_from_checkpoint_smoke_test( <<<<<<< HEAD:deeplearning/ml4pl/models/zero_r/zero_r_test.py logger: logging.Logger, node_classification_graph_db: graph_tuple_database.Database, ======= logger: logging.Logger, node_y_graph_db: graph_tuple_database.Database, >>>>>>> de933d07a... Add a node text embedding enum.:deeplearning/ml4pl/models/ggnn/ggnn_test.py ): """Test creating and restoring model from checkpoint.""" run_id = run_id_lib.RunId.GenerateUnique( f"mock{random.randint(0, int(1e6)):06}" ) # Create and initialize an untrained model. <<<<<<< HEAD:deeplearning/ml4pl/models/zero_r/zero_r_test.py model = zero_r.ZeroR(logger, node_classification_graph_db, run_id=run_id) ======= model = ggnn.Ggnn(logger, node_y_graph_db, run_id=run_id) >>>>>>> de933d07a... Add a node text embedding enum.:deeplearning/ml4pl/models/ggnn/ggnn_test.py model.Initialize() # Smoke test save and restore. checkpoint_ref = model.SaveCheckpoint() model.RestoreFrom(checkpoint_ref) def test_node_classifier_call( <<<<<<< HEAD:deeplearning/ml4pl/models/zero_r/zero_r_test.py epoch_type: epoch.Type, node_classification_graph_db: graph_tuple_database.Database, logger: logging.Logger,