예제 #1
0
def MakeBatchIterator(
    model: classifier_base.ClassifierBase,
    graph_db: graph_tuple_database.Database,
    splits: Dict[epoch.Type, List[int]],
    epoch_type: epoch.Type,
    ctx: progress.ProgressContext = progress.NullContext,
) -> batches.BatchIterator:
    """Build the batch iterator that feeds one epoch of the given type.

    Graphs are read from `graph_db` through model.GraphReader(), restricted
    to the splits registered for `epoch_type`, then turned into batches by
    model.BatchIterator() and wrapped in a ppar.ThreadedIterator.

    Args:
      model: The model whose GraphReader() and BatchIterator() are used.
      graph_db: The graph database passed on to model.GraphReader().
      splits: A mapping from epoch type to a list of split numbers.
      epoch_type: The epoch type (TRAIN, VAL or TEST) to build batches for.
      ctx: A logging context.

    Returns:
      A batches.BatchIterator holding the threaded batch stream and the
      graph count reported by the graph reader (`graph_reader.n`).

    Raises:
      NotImplementedError: If `epoch_type` is not TRAIN, VAL or TEST.
    """
    # Pick the per-epoch graph cap. Only the flag belonging to the requested
    # epoch type is read; the test set is never capped.
    if epoch_type == epoch.Type.TRAIN:
        max_graphs = FLAGS.max_train_per_epoch
    elif epoch_type == epoch.Type.VAL:
        max_graphs = FLAGS.max_val_per_epoch
    elif epoch_type == epoch.Type.TEST:
        max_graphs = None
    else:
        raise NotImplementedError("unreachable")

    wanted_splits = splits[epoch_type]
    ctx.Log(
        3,
        "Using %s graph splits %s",
        epoch_type.name.lower(),
        sorted(wanted_splits),
    )

    # The filter is a zero-argument callable producing the query predicate.
    # A single split uses an equality test, several splits use in_().
    if len(wanted_splits) == 1:

        def split_filter():
            return graph_tuple_database.GraphTuple.split == wanted_splits[0]

    else:

        def split_filter():
            return graph_tuple_database.GraphTuple.split.in_(wanted_splits)

    graph_reader = model.GraphReader(
        epoch_type=epoch_type,
        graph_db=graph_db,
        filters=[split_filter],
        limit=max_graphs,
        ctx=ctx,
    )

    # Queue size of the threaded wrapper is taken from --batch_queue_size.
    threaded_batches = ppar.ThreadedIterator(
        model.BatchIterator(epoch_type, graph_reader, ctx=ctx),
        max_queue_size=FLAGS.batch_queue_size,
    )
    return batches.BatchIterator(
        batches=threaded_batches,
        graph_count=graph_reader.n,
    )
예제 #2
0
파일: run.py 프로젝트: fehrethe/ProGraML
def PrintExperimentHeader(model: classifier_base.ClassifierBase) -> None:
    """Print a banner describing the experiment to stdout.

    The banner is framed by two rule lines and contains, in order: the run's
    script name rendered with pyfiglet, the run ID, an ASCII table of the
    model parameters (with the "type" column shown as "parameter"), a blank
    line, and the model summary.

    Args:
      model: The model whose run_id, parameters and Summary() are printed.
    """
    rule = "=================================================================="

    print(rule)
    print(pyfiglet.figlet_format(model.run_id.script_name))
    print("Run ID:", model.run_id)

    # Keep only the displayed columns and relabel "type" for the table header.
    param_table = model.parameters[["type", "name", "value"]].rename(
        columns={"type": "parameter"}
    )
    print(pdutil.FormatDataFrameAsAsciiTable(param_table))
    print()

    print(model.Summary())
    print(rule)