Example #1
def test_load_restore_model_from_checkpoint_smoke_test(
  logger: logging.Logger,
  graph_db: graph_tuple_database.Database,
  proto_db: unlabelled_graph_database.Database,
):
  """Test creating and restoring a model from checkpoint."""
  # Create and initialize a model.
  model = node_lstm.NodeLstm(
    logger,
    graph_db,
    proto_db=proto_db,
    batch_size=32,
    padded_sequence_length=10,
    padded_node_sequence_length=10,
  )
  model.Initialize()

  # Create a checkpoint from the model.
  checkpoint_ref = model.SaveCheckpoint()

  # Reset the model state to the checkpoint.
  model.RestoreFrom(checkpoint_ref)

  # Run a test epoch to make sure the restored model works.
  batch_iterator = batch_iterator_lib.MakeBatchIterator(
    model=model,
    graph_db=graph_db,
    splits={epoch.Type.TRAIN: [0], epoch.Type.VAL: [1], epoch.Type.TEST: [2],},
    epoch_type=epoch.Type.TEST,
  )
  model(
    epoch_type=epoch.Type.TEST, batch_iterator=batch_iterator, logger=logger,
  )

  # Create a new model instance and restore its state from the checkpoint.
  new_model = node_lstm.NodeLstm(
    logger,
    graph_db,
    proto_db=proto_db,
    batch_size=32,
    padded_sequence_length=10,
    padded_node_sequence_length=10,
  )
  new_model.RestoreFrom(checkpoint_ref)

  # Check that the new model works.
  batch_iterator = batch_iterator_lib.MakeBatchIterator(
    model=new_model,
    graph_db=graph_db,
    splits={epoch.Type.TRAIN: [0], epoch.Type.VAL: [1], epoch.Type.TEST: [2],},
    epoch_type=epoch.Type.TEST,
  )
  new_model(
    epoch_type=epoch.Type.TEST, batch_iterator=batch_iterator, logger=logger,
  )
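
Note the two restore paths this test exercises: RestoreFrom on the live, already-initialized model, and RestoreFrom on a second, freshly constructed instance whose Initialize() is never called. Both paths must yield a model that can complete a test epoch.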
Example #2
File: run.py Project: fehrethe/ProGraML
def RunOne(
    model_class,
    graph_db: graph_tuple_database.Database,
    print_header: bool = True,
    print_footer: bool = True,
    ctx: progress.ProgressContext = progress.NullContext,
) -> pd.Series:
    with logger_lib.Logger.FromFlags() as logger:
        logger.ctx = ctx
        model = CreateModel(model_class, graph_db, logger)

        if print_header:
            PrintExperimentHeader(model)

        splits = SplitsFromFlags(graph_db)

        if FLAGS.test_only:
            batch_iterator = batch_iterator_lib.MakeBatchIterator(
                model=model,
                graph_db=graph_db,
                splits=splits,
                epoch_type=epoch.Type.TEST,
            )
            RunEpoch(
                epoch_name="test",
                model=model,
                batch_iterator=batch_iterator,
                epoch_type=epoch.Type.TEST,
                logger=logger,
            )
        else:
            train = Train(model=model,
                          graph_db=graph_db,
                          logger=logger,
                          splits=splits)
            progress.Run(train)
            if train.ctx.i != train.ctx.n:
                raise RunError("Model failed")

        # Get the results for the best epoch.
        epochs = GetModelEpochsTable(model, logger)

        # Select only from the epochs with test accuracy, if available.
        only_with_test_epochs = epochs[
            (epochs["test_accuracy"].astype(str) != "-")
            & (epochs["test_accuracy"].notnull())]
        if len(only_with_test_epochs):
            epochs = only_with_test_epochs

        epochs.reset_index(inplace=True)

        # Select the row with the greatest validation accuracy.
        # TODO(github.com/ChrisCummins/ProGraML/issues/38): Find the memory leak.
        best_epoch = copy.deepcopy(epochs.loc[epochs["val_accuracy"].idxmax()])
        del epochs

        if print_footer:
            PrintExperimentFooter(model, best_epoch)

        return best_epoch
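
A minimal sketch of a call site for RunOne, assuming an open graph_tuple_database.Database and flag-configured model; ggnn.Ggnn stands in for any model_class:

# Hypothetical usage: run a single experiment and read metrics from the
# returned pandas Series. Assumes `graph_db` is an open
# graph_tuple_database.Database and that FLAGS carries the configuration
# which CreateModel and SplitsFromFlags read.
best_epoch = RunOne(ggnn.Ggnn, graph_db)

# RunOne returns the row with the greatest validation accuracy, so these
# epochs-table columns are present:
print(best_epoch["val_accuracy"], best_epoch["test_accuracy"])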
Example #3
def test_graph_classifier_call(
  epoch_type: epoch.Type,
  logger: logging.Logger,
  graph_classification_graph_db: graph_tuple_database.Database,
):
  """Test running a graph classifier."""
  run_id = run_id_lib.RunId.GenerateUnique(
    f"mock{random.randint(0, int(1e6)):06}"
  )

  # Create and initialize an untrained model.
  model = zero_r.ZeroR(logger, graph_classification_graph_db, run_id=run_id)
  model.Initialize()

  # Run the model over some random graphs.
  batch_iterator = batch_iterator_lib.MakeBatchIterator(
    model=model,
    graph_db=graph_classification_graph_db,
    splits={epoch.Type.TRAIN: [0], epoch.Type.VAL: [1], epoch.Type.TEST: [2],},
    epoch_type=epoch_type,
  )

  results = model(
    epoch_type=epoch_type, batch_iterator=batch_iterator, logger=logger,
  )
  assert isinstance(results, epoch.Results)

  assert results.batch_count
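
The epoch_type argument here is a parametrized test fixture, not a plain parameter. A minimal sketch of such a fixture using labm8's test wrapper (the module that also provides test.Raises in Example #7); the scope and parametrization below are assumptions, not copied from the ProGraML test sources:

from labm8.py import test


@test.Fixture(
  scope="session", params=(epoch.Type.TRAIN, epoch.Type.VAL, epoch.Type.TEST)
)
def epoch_type(request) -> epoch.Type:
  """Enumerate tests over all three epoch types."""
  return request.param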
Example #4
def test_graph_classifier_call(
  epoch_type: epoch.Type,
  logger: logging.Logger,
  graph_y_graph_db: graph_tuple_database.Database,
  node_text_embedding_type,
  log1p_graph_x: bool,
):
  """Test running a graph classifier."""
  FLAGS.inst2vec_embeddings = node_text_embedding_type
  FLAGS.log1p_graph_x = log1p_graph_x

  # Create and initialize an untrained model.
  model = ggnn.Ggnn(logger, graph_y_graph_db)
  model.Initialize()

  # Run the model over some random graphs.
  batch_iterator = batch_iterator_lib.MakeBatchIterator(
    model=model,
    graph_db=graph_y_graph_db,
    splits={epoch.Type.TRAIN: [0], epoch.Type.VAL: [1], epoch.Type.TEST: [2],},
    epoch_type=epoch_type,
  )

  results = model(
    epoch_type=epoch_type, batch_iterator=batch_iterator, logger=logger,
  )

  CheckResultsProperties(results, graph_y_graph_db, epoch_type)
Example #5
def test_node_classifier_call(
    epoch_type: epoch.Type,
    node_classification_graph_db: graph_tuple_database.Database,
    logger: logging.Logger,
):
    """Test running a node classifier."""
    # Create and initialize an untrained model.
    model = zero_r.ZeroR(logger, node_classification_graph_db)
    model.Initialize()

    # Run the model over some random graphs.
    batch_iterator = batch_iterator_lib.MakeBatchIterator(
        model=model,
        graph_db=node_classification_graph_db,
        splits={
            epoch.Type.TRAIN: [0],
            epoch.Type.VAL: [1],
            epoch.Type.TEST: [2],
        },
        epoch_type=epoch_type,
    )

    results = model(
        epoch_type=epoch_type,
        batch_iterator=batch_iterator,
        logger=logger,
    )
    assert isinstance(results, epoch.Results)

    assert results.batch_count
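
Every test above passes the same splits mapping. The helper below is illustrative, not part of ProGraML; it shows the shape of the mapping MakeBatchIterator expects, which RunOne instead derives from command-line flags via SplitsFromFlags:

from typing import Dict, List


def MakeSplits(
    train: List[int], val: List[int], test: List[int]
) -> Dict[epoch.Type, List[int]]:
    """Map each epoch type to the dataset split numbers it draws graphs from."""
    return {
        epoch.Type.TRAIN: train,
        epoch.Type.VAL: val,
        epoch.Type.TEST: test,
    }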
Example #6
File: run.py Project: fehrethe/ProGraML
def MakeBatchIterator(self,
                      epoch_type: epoch.Type) -> batchs.BatchIterator:
    """Construct a batch iterator."""
    return batch_iterator_lib.MakeBatchIterator(
        model=self.model,
        graph_db=self.graph_db,
        splits=self.splits,
        epoch_type=epoch_type,
        ctx=self.ctx,
    )
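
This method reads the model, graph database, split map, and progress context from attributes of its enclosing object, which the snippet does not show. A minimal sketch of such a holder, with attribute names taken from the method body (the class name and constructor are hypothetical):

from typing import Dict, List


class EpochRunner(object):  # Hypothetical name; the real class is not shown.
    """Carries the state that the MakeBatchIterator method above reads."""

    def __init__(
        self,
        model,
        graph_db: graph_tuple_database.Database,
        splits: Dict[epoch.Type, List[int]],
        ctx: progress.ProgressContext = progress.NullContext,
    ):
        self.model = model
        self.graph_db = graph_db
        self.splits = splits
        self.ctx = ctx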
Example #7
def test_node_classifier_call(
    epoch_type: epoch.Type,
    logger: logging.Logger,
    layer_timesteps: List[str],
    node_y_graph_db: graph_tuple_database.Database,
    node_text_embedding_type,
    unroll_strategy: str,
    log1p_graph_x: bool,
    limit_max_data_flow_steps: bool,
):
    """Test running a node classifier."""
    FLAGS.inst2vec_embeddings = node_text_embedding_type
    FLAGS.unroll_strategy = unroll_strategy
    FLAGS.layer_timesteps = layer_timesteps
    FLAGS.log1p_graph_x = log1p_graph_x
    FLAGS.limit_max_data_flow_steps = limit_max_data_flow_steps

    # Check that the unsupported combination of config values is rejected:
    # label_convergence unrolling supports neither graph-level features nor
    # more than one layer timestep.
    if unroll_strategy == "label_convergence" and (
            node_y_graph_db.graph_x_dimensionality or len(layer_timesteps) > 1):
        with test.Raises(AssertionError):
            ggnn.Ggnn(logger, node_y_graph_db)
        return

    # Create and initialize an untrained model.
    model = ggnn.Ggnn(logger, node_y_graph_db)
    model.Initialize()

    # Run the model over some random graphs.
    batch_iterator = batch_iterator_lib.MakeBatchIterator(
        model=model,
        graph_db=node_y_graph_db,
        splits={
            epoch.Type.TRAIN: [0],
            epoch.Type.VAL: [1],
            epoch.Type.TEST: [2],
        },
        epoch_type=epoch_type,
    )

    results = model(
        epoch_type=epoch_type,
        batch_iterator=batch_iterator,
        logger=logger,
    )

    CheckResultsProperties(results, node_y_graph_db, epoch_type)
Example #8
def test_classifier_call(
    epoch_type: epoch.Type,
    logger: logging.Logger,
    graph_db: graph_tuple_database.Database,
    ir_db: ir_database.Database,
):
    """Test running a graph classifier."""
    run_id = run_id_lib.RunId.GenerateUnique(
        f"mock{random.randint(0, int(1e6)):06}")

    model = graph_lstm.GraphLstm(
        logger,
        graph_db,
        ir_db=ir_db,
        batch_size=8,
        padded_sequence_length=100,
        run_id=run_id,
    )
    model.Initialize()

    batch_iterator = batch_iterator_lib.MakeBatchIterator(
        model=model,
        graph_db=graph_db,
        splits={
            epoch.Type.TRAIN: [0],
            epoch.Type.VAL: [1],
            epoch.Type.TEST: [2],
        },
        epoch_type=epoch_type,
    )

    results = model(
        epoch_type=epoch_type,
        batch_iterator=batch_iterator,
        logger=logger,
    )
    assert isinstance(results, epoch.Results)

    assert results.batch_count

    # We only get loss for training.
    if epoch_type == epoch.Type.TRAIN:
        assert results.has_loss
    else:
        assert not results.has_loss
Example #9
def test_classifier_call(
    epoch_type: epoch.Type,
    logger: logging.Logger,
    graph_db: graph_tuple_database.Database,
    proto_db: unlabelled_graph_database.Database,
):
    """Test running a node classifier."""
    model = node_lstm.NodeLstm(
        logger,
        graph_db,
        proto_db=proto_db,
        batch_size=32,
        padded_sequence_length=100,
        padded_node_sequence_length=50,
    )
    model.Initialize()

    batch_iterator = batch_iterator_lib.MakeBatchIterator(
        model=model,
        graph_db=graph_db,
        splits={
            epoch.Type.TRAIN: [0],
            epoch.Type.VAL: [1],
            epoch.Type.TEST: [2],
        },
        epoch_type=epoch_type,
    )

    results = model(
        epoch_type=epoch_type,
        batch_iterator=batch_iterator,
        logger=logger,
    )
    assert isinstance(results, epoch.Results)

    assert results.batch_count

    # We only get loss for training.
    if epoch_type == epoch.Type.TRAIN:
        assert results.has_loss
    else:
        assert not results.has_loss
Example #10
def test_node_classifier_call(
  epoch_type: epoch.Type,
  node_classification_graph_db: graph_tuple_database.Database,
  logger: logging.Logger,
):
  """Test running a node classifier."""
  run_id = run_id_lib.RunId.GenerateUnique(
    f"mock{random.randint(0, int(1e6)):06}"
  )

  # Create and initialize an untrained model.
  model = zero_r.ZeroR(logger, node_classification_graph_db, run_id=run_id)
  model.Initialize()

  # Run the model over some random graphs.
  batch_iterator = batch_iterator_lib.MakeBatchIterator(
    model=model,
    graph_db=node_classification_graph_db,
    splits={epoch.Type.TRAIN: [0], epoch.Type.VAL: [1], epoch.Type.TEST: [2],},
    epoch_type=epoch_type,
  )

  results = model(
    epoch_type=epoch_type, batch_iterator=batch_iterator, logger=logger,
  )
  assert isinstance(results, epoch.Results)

  assert results.batch_count


Example #11
def RunOne(
    model_class,
    graph_db: graph_tuple_database.Database,
    print_header: bool = True,
    print_footer: bool = True,
    ctx: progress.ProgressContext = progress.NullContext,
) -> pd.Series:
    with logger_lib.Logger.FromFlags() as logger:
        logger.ctx = ctx
        model = CreateModel(model_class, graph_db, logger)

        if print_header:
            PrintExperimentHeader(model)

        splits = SplitsFromFlags(graph_db)

        if FLAGS.test_only:
            batch_iterator = batch_iterator_lib.MakeBatchIterator(
                model=model,
                graph_db=graph_db,
                splits=splits,
                epoch_type=epoch.Type.TEST,
            )
            RunEpoch(
                epoch_name="test",
                model=model,
                batch_iterator=batch_iterator,
                epoch_type=epoch.Type.TEST,
                logger=logger,
                epoch_name_prefix=FLAGS.tag[:28] if FLAGS.tag else "",
            )
        else:
            train = TrainValTestLoop(model=model,
                                     graph_db=graph_db,
                                     logger=logger,
                                     splits=splits)
            progress.Run(train)
            if train.ctx.i != train.ctx.n:
                raise RunError("Model failed")

        # Get the results for the best epoch.
        epochs = GetModelEpochsTable(model, logger)

        # Select only from the epochs with test accuracy, if available.
        only_with_test_epochs = epochs[
            (epochs["test_accuracy"].astype(str) != "-")
            & (epochs["test_accuracy"].notnull())]
        if len(only_with_test_epochs):
            epochs = only_with_test_epochs

        epochs.reset_index(inplace=True)

        # When running --test_only on a model without any training, we will have
        # no validation results to return, so just return the first run.
        if len(epochs) == 1:
            return epochs.iloc[0]

        # Select the row with the greatest validation accuracy.
        best_epoch = copy.deepcopy(epochs.loc[epochs["val_accuracy"].idxmax()])
        del epochs

        if print_footer:
            PrintExperimentFooter(model, best_epoch)

        return best_epoch