def TestDataflowLSTM(
  path: pathlib.Path,
  log_dir: pathlib.Path,
  vocab: Dict[str, int],
):
  dataflow.PatchWarnings()
  dataflow.RecordExperimentalSetup(log_dir)

  # Check that the logging directories exist, and create the graph loader's
  # log directory.
  assert (log_dir / "epochs").is_dir()
  assert (log_dir / "checkpoints").is_dir()
  (log_dir / "graph_loader").mkdir(exist_ok=True)

  # Create the model, defining the shape of the graphs that it will process.
  #
  # For these data flow experiments, our graphs contain per-node binary
  # classification targets (e.g. reachable / not-reachable).
  model = Lstm(
    vocabulary=vocab,
    test_only=True,
    node_y_dimensionality=2,
  )
  restored_epoch, checkpoint = dataflow.SelectTestCheckpoint(log_dir)
  model.RestoreCheckpoint(checkpoint)

  batches = MakeBatchBuilder(
    dataset_root=path,
    log_dir=log_dir,
    epoch_type=epoch_pb2.TEST,
    model=model,
    min_graph_count=1,
  )

  start_time = time.time()
  test_results = model.RunBatches(epoch_pb2.TEST, batches, log_prefix="Test")
  epoch = epoch_pb2.EpochList(
    epoch=[
      epoch_pb2.Epoch(
        walltime_seconds=time.time() - start_time,
        epoch_num=restored_epoch.epoch_num,
        test_results=test_results,
      )
    ]
  )
  print(epoch, end="")

  epoch_path = log_dir / "epochs" / "TEST.EpochList.pbtxt"
  pbutil.ToFile(epoch, epoch_path)
  logging.info("Wrote %s", epoch_path)
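
# For reference, TestDataflowLSTM expects a log directory laid out by a prior
# training run (see TrainDataflowLSTM below):
#
#   <log_dir>/
#     checkpoints/   # One NNN.Checkpoint.pb per completed epoch.
#     epochs/        # One NNN.EpochList.pbtxt per completed epoch.
#     graph_loader/  # Created above for graph loader logs.
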
def Main():
  # NOTE(github.com/ChrisCummins/ProGraML/issues/13): F1 score computation
  # warns that it is undefined when there are missing instances from a class,
  # which is fine for our usage.
  warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

  with data_directory() as path:
    print("=== BENCHMARK 1: Loading graphs from filesystem ===")
    graph_loader = GraphLoader(path)
    graphs = ppar.ThreadedIterator(graph_loader, max_queue_size=100)
    with prof.Profile("Benchmark graph loader"):
      for _ in tqdm(graphs, unit=" graphs"):
        pass

    print("=== BENCHMARK 2: Batch construction ===")
    model = Lstm(vocabulary=Vocab(), node_y_dimensionality=2)
    batches = BatchBuilder(model, GraphLoader(path), Vocab())
    batches = ppar.ThreadedIterator(batches, max_queue_size=100)
    cached_batches = []
    with prof.Profile("Benchmark batch construction"):
      for batch in tqdm(batches, unit=" batches"):
        cached_batches.append(batch)

    print("=== BENCHMARK 3: Model training ===")
    model.Initialize()
    model.model.summary()
    with prof.Profile("Benchmark training (prebuilt batches)"):
      model.RunBatches(
        epoch_pb2.TRAIN,
        cached_batches[: FLAGS.train_batch_count],
        log_prefix="Train",
        total_graph_count=sum(
          b.graph_count for b in cached_batches[: FLAGS.train_batch_count]
        ),
      )
    with prof.Profile("Benchmark training"):
      model.RunBatches(
        epoch_pb2.TRAIN,
        BatchBuilder(
          model, GraphLoader(path), Vocab(), FLAGS.train_batch_count
        ),
        log_prefix="Train",
      )

    print("=== BENCHMARK 4: Model inference ===")
    model = Lstm(
      vocabulary=Vocab(),
      node_y_dimensionality=2,
      test_only=True,
    )
    model.Initialize()
    with prof.Profile("Benchmark inference (prebuilt batches)"):
      model.RunBatches(
        epoch_pb2.TEST,
        cached_batches[: FLAGS.test_batch_count],
        log_prefix="Val",
        total_graph_count=sum(
          b.graph_count for b in cached_batches[: FLAGS.test_batch_count]
        ),
      )
    with prof.Profile("Benchmark inference"):
      model.RunBatches(
        epoch_pb2.TEST,
        BatchBuilder(model, GraphLoader(path), Vocab(), FLAGS.test_batch_count),
        log_prefix="Val",
      )
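
# The benchmarks above share a pattern: wrap a slow producer in
# ppar.ThreadedIterator so that graph loading / batch construction runs on a
# background thread, buffered in a bounded queue, while the main thread
# consumes. A minimal sketch of the same pattern on a toy producer (LoadOne
# and Consume are hypothetical names, not part of this file):
#
#   items = ppar.ThreadedIterator(
#       (LoadOne(p) for p in paths), max_queue_size=100
#   )
#   for item in items:
#     Consume(item)
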
def TrainDataflowLSTM(
  path: pathlib.Path,
  vocab: Dict[str, int],
  val_seed: int,
  restore_from: pathlib.Path,
) -> pathlib.Path:
  if not path.is_dir():
    raise FileNotFoundError(path)

  if restore_from:
    log_dir = restore_from
  else:
    # Create the logging directories.
    log_dir = dataflow.CreateLoggingDirectories(
      dataset_root=path,
      model_name="inst2vec",
      analysis=FLAGS.analysis,
      run_id=FLAGS.run_id,
    )

  dataflow.PatchWarnings()
  dataflow.RecordExperimentalSetup(log_dir)

  # Cumulative totals for training graph counts at each "epoch".
  train_graph_counts = [int(x) for x in FLAGS.train_graph_counts]
  train_graph_cumsums = np.array(train_graph_counts, dtype=np.int32)
  # The number of training graphs in each "epoch".
  train_graph_counts = train_graph_cumsums - np.concatenate(
    ([0], train_graph_counts[:-1])
  )

  # Create the model, defining the shape of the graphs that it will process.
  #
  # For these data flow experiments, our graphs contain per-node binary
  # classification targets (e.g. reachable / not-reachable).
  model = Lstm(
    vocabulary=vocab,
    test_only=False,
    node_y_dimensionality=2,
  )

  if restore_from:
    # Pick up training where we left off.
    restored_epoch, checkpoint = dataflow.SelectTrainingCheckpoint(log_dir)
    # Skip the epochs that we have already done.
    # This requires that --train_graph_counts is the same as it was in the
    # run that we are resuming!
    start_epoch_step = restored_epoch.epoch_num
    start_graph_cumsum = sum(train_graph_counts[:start_epoch_step])
    train_graph_counts = train_graph_counts[start_epoch_step:]
    train_graph_cumsums = train_graph_cumsums[start_epoch_step:]
    model.RestoreCheckpoint(checkpoint)
  else:
    # Else initialize a new model.
    model.Initialize()
    start_epoch_step, start_graph_cumsum = 1, 0

  model.model.summary()

  # Create training batches and split into epochs.
  epochs = EpochBatchIterator(
    MakeBatchBuilder(
      dataset_root=path,
      log_dir=log_dir,
      epoch_type=epoch_pb2.TRAIN,
      model=model,
      seed=val_seed,
    ),
    train_graph_counts,
    start_graph_count=start_graph_cumsum,
  )

  # Read val batches asynchronously.
  val_batches = AsyncBatchBuilder(
    batch_builder=MakeBatchBuilder(
      dataset_root=path,
      log_dir=log_dir,
      epoch_type=epoch_pb2.VAL,
      model=model,
      min_graph_count=FLAGS.val_graph_count,
      max_graph_count=FLAGS.val_graph_count,
      seed=val_seed,
    )
  )

  for (
    epoch_step,
    (train_graph_count, train_graph_cumsum, train_batches),
  ) in enumerate(epochs, start=start_epoch_step):
    start_time = time.time()
    hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"

    train_results = model.RunBatches(
      epoch_pb2.TRAIN,
      train_batches,
      log_prefix=f"Train to {hr_graph_cumsum}",
      total_graph_count=train_graph_count,
    )
    val_results = model.RunBatches(
      epoch_pb2.VAL,
      val_batches.batches,
      log_prefix=f"Val at {hr_graph_cumsum}",
      total_graph_count=FLAGS.val_graph_count,
    )

    # Write the epoch to file as an epoch list. This may seem redundant since
    # the epoch list contains a single item, but it means that we can easily
    # concatenate a sequence of these epoch protos to produce a valid epoch
    # list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
    epoch = epoch_pb2.EpochList(
      epoch=[
        epoch_pb2.Epoch(
          walltime_seconds=time.time() - start_time,
          epoch_num=epoch_step,
          train_results=train_results,
          val_results=val_results,
        )
      ]
    )
    print(epoch, end="")
    epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
    pbutil.ToFile(epoch, epoch_path)
    app.Log(1, "Wrote %s", epoch_path)

    checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
    pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

  return log_dir
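

def ReadEpochLogs(log_dir: pathlib.Path) -> epoch_pb2.EpochList:
  """Merge the per-epoch logs written by TrainDataflowLSTM into one list.

  A minimal sketch of consuming the NNN.EpochList.pbtxt files described
  above; this helper is illustrative only and is not used by the training
  loop. It is the in-process equivalent of the
  `cat *.EpochList.pbtxt > epochs.pbtxt` trick noted above.
  """
  epochs = epoch_pb2.EpochList()
  # The zero-padded file names mean a lexicographic sort is also a sort by
  # epoch number. The [0-9] prefix excludes TEST.EpochList.pbtxt.
  for epoch_path in sorted((log_dir / "epochs").glob("[0-9]*.EpochList.pbtxt")):
    # Each file holds an EpochList with a single Epoch; concatenate them.
    epochs.epoch.extend(pbutil.FromFile(epoch_path, epoch_pb2.EpochList()).epoch)
  return epochs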