def SelectTrainingCheckpoint(
    log_dir: pathlib.Path,
) -> Tuple[epoch_pb2.Epoch, checkpoint_pb2.Checkpoint]:
    """Select a checkpoint to load to resume training.

    Returns:
        A tuple of <Epoch, Checkpoint> messages.
    """
    epoch_num = -1
    for path in (log_dir / "epochs").iterdir():
        if path.name.endswith(".EpochList.pbtxt"):
            epoch = pbutil.FromFile(path, epoch_pb2.EpochList())
            if not epoch.epoch[0].train_results.graph_count:
                continue
            epoch_num = max(epoch_num, epoch.epoch[0].epoch_num)

    epoch = pbutil.FromFile(
        log_dir / "epochs" / f"{epoch_num:03d}.EpochList.pbtxt",
        epoch_pb2.EpochList(),
    )
    checkpoint = pbutil.FromFile(
        log_dir / "checkpoints" / f"{epoch_num:03d}.Checkpoint.pb",
        checkpoint_pb2.Checkpoint(),
    )
    app.Log(
        1,
        "Resuming training from checkpoint %d with val F1 score %.3f",
        epoch.epoch[0].epoch_num,
        epoch.epoch[0].val_results.mean_f1,
    )
    return epoch.epoch[0], checkpoint
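# Example usage (a sketch, not part of the original module: `model` is assumed
# to be an already-constructed Ggnn or Lstm instance, following the resume
# pattern used by TrainDataflowGGNN() below):
#
#   restored_epoch, checkpoint = SelectTrainingCheckpoint(log_dir)
#   model.RestoreCheckpoint(checkpoint)
#   # Resume epoch numbering where the restored run left off.
#   start_epoch_step = restored_epoch.epoch_num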
def SelectTestCheckpoint(
    log_dir: pathlib.Path,
) -> Tuple[epoch_pb2.Epoch, checkpoint_pb2.Checkpoint]:
    """Select a checkpoint to load for testing.

    The training checkpoint with the highest validation F1 score is used for
    testing.

    Returns:
        A tuple of <Epoch, Checkpoint> messages.
    """
    best_f1 = -1
    best_epoch_num = None
    for path in (log_dir / "epochs").iterdir():
        if path.name.endswith(".EpochList.pbtxt"):
            epoch = pbutil.FromFile(path, epoch_pb2.EpochList())
            f1 = epoch.epoch[0].val_results.mean_f1
            epoch_num = epoch.epoch[0].epoch_num
            if f1 >= best_f1:
                best_f1 = f1
                best_epoch_num = epoch_num

    epoch = pbutil.FromFile(
        log_dir / "epochs" / f"{best_epoch_num:03d}.EpochList.pbtxt",
        epoch_pb2.EpochList(),
    )
    checkpoint = pbutil.FromFile(
        log_dir / "checkpoints" / f"{best_epoch_num:03d}.Checkpoint.pb",
        checkpoint_pb2.Checkpoint(),
    )
    app.Log(
        1,
        "Selected best checkpoint %d with val F1 score %.3f",
        epoch.epoch[0].epoch_num,
        epoch.epoch[0].val_results.mean_f1,
    )
    return epoch.epoch[0], checkpoint
def ReadEpochLogs(path: pathlib.Path) -> Optional[epoch_pb2.EpochList]:
    """Read and merge all epoch logs under a log directory.

    Returns:
        A single EpochList sorted by epoch number, or None if the log
        directory contains no epochs.
    """
    if not (path / "epochs").is_dir():
        return None
    epochs = []
    for epoch_path in (path / "epochs").iterdir():
        epoch = pbutil.FromFile(epoch_path, epoch_pb2.EpochList())
        # Skip files without data.
        if not len(epoch.epoch):
            continue
        epochs += list(epoch.epoch)
    return epoch_pb2.EpochList(epoch=sorted(epochs, key=lambda x: x.epoch_num))
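# A hypothetical helper (not part of the original module) sketching how
# ReadEpochLogs() composes with the best-validation-F1 selection logic of
# SelectTestCheckpoint(): scan the merged epoch list and report the epoch
# number with the highest validation F1 score. Uses only message fields that
# appear above (epoch_num, val_results.mean_f1).
def BestValEpochNum(log_dir: pathlib.Path) -> Optional[int]:
    epochs = ReadEpochLogs(log_dir)
    if epochs is None or not len(epochs.epoch):
        return None
    best = max(epochs.epoch, key=lambda e: e.val_results.mean_f1)
    return best.epoch_num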
def TestDataflowGGNN(
    path: pathlib.Path,
    log_dir: pathlib.Path,
    analysis: str,
    vocab: Dict[str, int],
    limit_max_data_flow_steps: bool,
    batch_size: int,
    use_cdfg: bool,
):
    dataflow.PatchWarnings()
    dataflow.RecordExperimentalSetup(log_dir)

    # Create the logging directories.
    assert (log_dir / "epochs").is_dir()
    assert (log_dir / "checkpoints").is_dir()
    (log_dir / "graph_loader").mkdir(exist_ok=True)

    # Create the model, defining the shape of the graphs that it will process.
    #
    # For these data flow experiments, our graphs contain per-node binary
    # classification targets (e.g. reachable / not-reachable).
    model = Ggnn(
        vocabulary=vocab,
        test_only=True,
        node_y_dimensionality=2,
        graph_y_dimensionality=0,
        graph_x_dimensionality=0,
        use_selector_embeddings=True,
    )
    restored_epoch, checkpoint = dataflow.SelectTestCheckpoint(log_dir)
    model.RestoreCheckpoint(checkpoint)

    batches = MakeBatchBuilder(
        dataset_root=path,
        log_dir=log_dir,
        epoch_type=epoch_pb2.TEST,
        analysis=analysis,
        model=model,
        batch_size=batch_size,
        use_cdfg=use_cdfg,
        # Specify that we require at least one graph, as the default (no
        # minimum) will loop forever.
        min_graph_count=1,
        limit_max_data_flow_steps=limit_max_data_flow_steps,
    )

    start_time = time.time()
    test_results = model.RunBatches(epoch_pb2.TEST, batches, log_prefix="Test")
    epoch = epoch_pb2.EpochList(
        epoch=[
            epoch_pb2.Epoch(
                walltime_seconds=time.time() - start_time,
                epoch_num=restored_epoch.epoch_num,
                test_results=test_results,
            )
        ]
    )
    print(epoch, end="")

    epoch_path = log_dir / "epochs" / "TEST.EpochList.pbtxt"
    pbutil.ToFile(epoch, epoch_path)
    app.Log(1, "Wrote %s", epoch_path)
def TestDataflowLSTM(
    path: pathlib.Path,
    log_dir: pathlib.Path,
    vocab: Dict[str, int],
):
    dataflow.PatchWarnings()
    dataflow.RecordExperimentalSetup(log_dir)

    # Create the logging directories.
    assert (log_dir / "epochs").is_dir()
    assert (log_dir / "checkpoints").is_dir()
    (log_dir / "graph_loader").mkdir(exist_ok=True)

    # Create the model, defining the shape of the graphs that it will process.
    #
    # For these data flow experiments, our graphs contain per-node binary
    # classification targets (e.g. reachable / not-reachable).
    model = Lstm(
        vocabulary=vocab,
        test_only=True,
        node_y_dimensionality=2,
    )
    restored_epoch, checkpoint = dataflow.SelectTestCheckpoint(log_dir)
    model.RestoreCheckpoint(checkpoint)

    batches = MakeBatchBuilder(
        dataset_root=path,
        log_dir=log_dir,
        epoch_type=epoch_pb2.TEST,
        model=model,
        # Require at least one graph, as the default (no minimum) will loop
        # forever.
        min_graph_count=1,
    )

    start_time = time.time()
    test_results = model.RunBatches(epoch_pb2.TEST, batches, log_prefix="Test")
    epoch = epoch_pb2.EpochList(
        epoch=[
            epoch_pb2.Epoch(
                walltime_seconds=time.time() - start_time,
                epoch_num=restored_epoch.epoch_num,
                test_results=test_results,
            )
        ]
    )
    print(epoch, end="")

    epoch_path = log_dir / "epochs" / "TEST.EpochList.pbtxt"
    pbutil.ToFile(epoch, epoch_path)
    app.Log(1, "Wrote %s", epoch_path)
def TrainDataflowGGNN(
    path: pathlib.Path,
    analysis: str,
    vocab: Dict[str, int],
    limit_max_data_flow_steps: bool,
    train_graph_counts: List[int],
    val_graph_count: int,
    val_seed: int,
    batch_size: int,
    use_cdfg: bool,
    run_id: Optional[str] = None,
    restore_from: Optional[pathlib.Path] = None,
) -> pathlib.Path:
    if not path.is_dir():
        raise FileNotFoundError(path)

    if restore_from:
        log_dir = restore_from
    else:
        # Create the logging directories.
        log_dir = dataflow.CreateLoggingDirectories(
            dataset_root=path,
            model_name="cdfg" if use_cdfg else "programl",
            analysis=analysis,
            run_id=run_id,
        )

    dataflow.PatchWarnings()
    dataflow.RecordExperimentalSetup(log_dir)

    # Cumulative totals for training graph counts at each "epoch".
    train_graph_cumsums = np.array(train_graph_counts, dtype=np.int32)
    # The number of training graphs in each "epoch".
    train_graph_counts = train_graph_cumsums - np.concatenate(
        ([0], train_graph_counts[:-1])
    )

    # Create the model, defining the shape of the graphs that it will process.
    #
    # For these data flow experiments, our graphs contain per-node binary
    # classification targets (e.g. reachable / not-reachable).
    model = Ggnn(
        vocabulary=vocab,
        test_only=False,
        node_y_dimensionality=2,
        graph_y_dimensionality=0,
        graph_x_dimensionality=0,
        use_selector_embeddings=True,
    )

    if restore_from:
        # Pick up training where we left off.
        restored_epoch, checkpoint = dataflow.SelectTrainingCheckpoint(log_dir)
        # Skip the epochs that we have already done.
        # This requires that --train_graph_counts is the same as it was in the
        # run that we are resuming!
        start_epoch_step = restored_epoch.epoch_num
        start_graph_cumsum = sum(train_graph_counts[:start_epoch_step])
        train_graph_counts = train_graph_counts[start_epoch_step:]
        model.RestoreCheckpoint(checkpoint)
    else:
        # Else initialize a new model.
        model.Initialize()
        start_epoch_step, start_graph_cumsum = 1, 0

    app.Log(
        1,
        "GGNN has %s training params",
        humanize.Commas(model.trainable_parameter_count),
    )

    # Create training batches and split into epochs.
    epochs = EpochBatchIterator(
        MakeBatchBuilder(
            dataset_root=path,
            log_dir=log_dir,
            epoch_type=epoch_pb2.TRAIN,
            analysis=analysis,
            model=model,
            batch_size=batch_size,
            use_cdfg=use_cdfg,
            limit_max_data_flow_steps=limit_max_data_flow_steps,
        ),
        train_graph_counts,
        start_graph_count=start_graph_cumsum,
    )

    # Read val batches asynchronously.
    val_batches = AsyncBatchBuilder(
        MakeBatchBuilder(
            dataset_root=path,
            log_dir=log_dir,
            epoch_type=epoch_pb2.VAL,
            analysis=analysis,
            model=model,
            batch_size=batch_size,
            use_cdfg=use_cdfg,
            limit_max_data_flow_steps=limit_max_data_flow_steps,
            min_graph_count=val_graph_count,
            max_graph_count=val_graph_count,
            seed=val_seed,
        )
    )

    for (
        epoch_step,
        (train_graph_count, train_graph_cumsum, train_batches),
    ) in enumerate(epochs, start=start_epoch_step):
        start_time = time.time()
        hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"

        train_results = model.RunBatches(
            epoch_pb2.TRAIN,
            train_batches,
            log_prefix=f"Train to {hr_graph_cumsum}",
            total_graph_count=train_graph_count,
        )
        val_results = model.RunBatches(
            epoch_pb2.VAL,
            val_batches.batches,
            log_prefix=f"Val at {hr_graph_cumsum}",
            total_graph_count=val_graph_count,
        )

        # Write the epoch to file as an epoch list. This may seem redundant
        # since the epoch list contains a single item, but it means that we
        # can easily concatenate a sequence of these epoch protos to produce
        # a valid epoch list using:
        # `cat *.EpochList.pbtxt > epochs.pbtxt`
        epoch = epoch_pb2.EpochList(
            epoch=[
                epoch_pb2.Epoch(
                    walltime_seconds=time.time() - start_time,
                    epoch_num=epoch_step,
                    train_results=train_results,
                    val_results=val_results,
                )
            ]
        )
        print(epoch, end="")
        epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
        pbutil.ToFile(epoch, epoch_path)
        app.Log(1, "Wrote %s", epoch_path)

        checkpoint_path = (
            log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
        )
        pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

    return log_dir
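# As noted in the epoch-writing comment above, the single-item EpochList files
# concatenate into one valid EpochList. A Python equivalent of
# `cat *.EpochList.pbtxt > epochs.pbtxt` (a sketch, assuming only the pbutil
# and epoch_pb2 modules already used in this file):
#
#   merged = epoch_pb2.EpochList()
#   for p in sorted((log_dir / "epochs").glob("*.EpochList.pbtxt")):
#       merged.epoch.extend(pbutil.FromFile(p, epoch_pb2.EpochList()).epoch)
#   pbutil.ToFile(merged, log_dir / "epochs.pbtxt")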