def MakeBatchIterator(
  model: classifier_base.ClassifierBase,
  graph_db: graph_tuple_database.Database,
  splits: Dict[epoch.Type, List[int]],
  epoch_type: epoch.Type,
  ctx: progress.ProgressContext = progress.NullContext,
) -> batches.BatchIterator:
  """Create an iterator over batches for the given epoch type.

  Args:
    model: The model to generate a batch iterator for.
    graph_db: The database to read graph tuples from.
    splits: A mapping from epoch type to a list of split numbers.
    epoch_type: The type of epoch to produce an iterator for.
    ctx: A logging context.

  Returns:
    A batch iterator for feeding into model.RunBatch().

  Raises:
    NotImplementedError: If epoch_type is not a recognized epoch type.
  """
  # Select the per-epoch graph limit for this epoch type.
  if epoch_type == epoch.Type.TRAIN:
    limit = FLAGS.max_train_per_epoch
  elif epoch_type == epoch.Type.VAL:
    limit = FLAGS.max_val_per_epoch
  elif epoch_type == epoch.Type.TEST:
    limit = None  # Never limit the test set.
  else:
    raise NotImplementedError("unreachable")

  splits_for_type = splits[epoch_type]
  ctx.Log(
    3,
    "Using %s graph splits %s",
    epoch_type.name.lower(),
    sorted(splits_for_type),
  )

  # Filter the graph database to load graphs from the requested splits. An
  # equality test is used for a single split as it produces a simpler SQL
  # expression than IN (...).
  if len(splits_for_type) == 1:
    split_filter = (
      lambda: graph_tuple_database.GraphTuple.split == splits_for_type[0]
    )
  else:
    split_filter = lambda: graph_tuple_database.GraphTuple.split.in_(
      splits_for_type
    )

  graph_reader = model.GraphReader(
    epoch_type=epoch_type,
    graph_db=graph_db,
    filters=[split_filter],
    limit=limit,
    ctx=ctx,
  )

  # Read batches on a background thread so that batch construction can
  # overlap with model execution.
  return batches.BatchIterator(
    batches=ppar.ThreadedIterator(
      model.BatchIterator(epoch_type, graph_reader, ctx=ctx),
      max_queue_size=FLAGS.batch_queue_size,
    ),
    graph_count=graph_reader.n,
  )
def PrintExperimentHeader(model: classifier_base.ClassifierBase) -> None:
  """Print a banner describing the model run, its parameters, and summary."""
  print("==================================================================")
  # Render the script name as large ASCII-art text.
  print(pyfiglet.figlet_format(model.run_id.script_name))
  print("Run ID:", model.run_id)
  # Show the model parameters as an ASCII table, relabeling the "type"
  # column as "parameter".
  parameter_table = model.parameters[["type", "name", "value"]].rename(
    columns={"type": "parameter"}
  )
  print(pdutil.FormatDataFrameAsAsciiTable(parameter_table))
  print()
  print(model.Summary())
  print("==================================================================")