def SelectTrainingCheckpoint(
  log_dir: pathlib.Path,
) -> Tuple[epoch_pb2.Epoch, checkpoint_pb2.Checkpoint]:
  """Select a checkpoint to load to resume training.

  Returns:
    A tuple of <Epoch, Checkpoint> messages.
  """
  epoch_num = -1
  for path in (log_dir / "epochs").iterdir():
    if path.name.endswith(".EpochList.pbtxt"):
      epoch = pbutil.FromFile(path, epoch_pb2.EpochList())
      if not epoch.epoch[0].train_results.graph_count:
        continue
      epoch_num = max(epoch_num, epoch.epoch[0].epoch_num)
  epoch = pbutil.FromFile(
    log_dir / "epochs" / f"{epoch_num:03d}.EpochList.pbtxt",
    epoch_pb2.EpochList(),
  )
  checkpoint = pbutil.FromFile(
    log_dir / "checkpoints" / f"{epoch_num:03d}.Checkpoint.pb",
    checkpoint_pb2.Checkpoint(),
  )
  app.Log(
    1,
    "Resuming training from checkpoint %d with val F1 score %.3f",
    epoch.epoch[0].epoch_num,
    epoch.epoch[0].val_results.mean_f1,
  )
  return epoch.epoch[0], checkpoint
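# --- Illustrative sketch (added, not original code) -------------------------
# One plausible way to use SelectTrainingCheckpoint() when resuming an
# interrupted run. The Ggnn constructor arguments are copied from TestOne()
# below (with test_only=False); the `vocab` argument and the module-level
# names (pathlib, Tuple, Ggnn) are assumed to be available in the surrounding
# training script. This is a sketch, not the project's actual resume logic.
def ResumeTrainingSketch(log_dir: pathlib.Path, vocab) -> Tuple[int, "Ggnn"]:
  epoch, checkpoint = SelectTrainingCheckpoint(log_dir)
  model = Ggnn(
    vocabulary=vocab,
    test_only=False,
    node_y_dimensionality=2,
    graph_y_dimensionality=0,
    graph_x_dimensionality=0,
    use_selector_embeddings=True,
  )
  model.RestoreCheckpoint(checkpoint)
  # Continue training with the epoch after the one that was restored.
  return epoch.epoch_num + 1, model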
def SelectTestCheckpoint(
  log_dir: pathlib.Path,
) -> Tuple[epoch_pb2.Epoch, checkpoint_pb2.Checkpoint]:
  """Select a checkpoint to load for testing.

  The training checkpoint with the highest validation F1 score is used for
  testing.

  Returns:
    A tuple of <Epoch, Checkpoint> messages.
  """
  best_f1 = -1
  best_epoch_num = None
  for path in (log_dir / "epochs").iterdir():
    if path.name.endswith(".EpochList.pbtxt"):
      epoch = pbutil.FromFile(path, epoch_pb2.EpochList())
      f1 = epoch.epoch[0].val_results.mean_f1
      epoch_num = epoch.epoch[0].epoch_num
      if f1 >= best_f1:
        best_f1 = f1
        best_epoch_num = epoch_num
  epoch = pbutil.FromFile(
    log_dir / "epochs" / f"{best_epoch_num:03d}.EpochList.pbtxt",
    epoch_pb2.EpochList(),
  )
  checkpoint = pbutil.FromFile(
    log_dir / "checkpoints" / f"{best_epoch_num:03d}.Checkpoint.pb",
    checkpoint_pb2.Checkpoint(),
  )
  app.Log(
    1,
    "Selected best checkpoint %d with val F1 score %.3f",
    epoch.epoch[0].epoch_num,
    epoch.epoch[0].val_results.mean_f1,
  )
  return epoch.epoch[0], checkpoint
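# --- Illustrative sketch (added, not original code) -------------------------
# Sanity-check helper showing which checkpoint SelectTestCheckpoint() will
# pick: it lists each epoch's validation F1 score using only the proto fields
# already read above.
def ListValidationF1Sketch(log_dir: pathlib.Path) -> None:
  for path in sorted((log_dir / "epochs").iterdir()):
    if path.name.endswith(".EpochList.pbtxt"):
      epoch = pbutil.FromFile(path, epoch_pb2.EpochList())
      print(
        f"epoch {epoch.epoch[0].epoch_num:03d}: "
        f"val F1 {epoch.epoch[0].val_results.mean_f1:.3f}"
      )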
def SaveCheckpoint(self) -> checkpoint_pb2.Checkpoint:
  """Construct a checkpoint from the current model state.

  Returns:
    A Checkpoint message containing the pickled model state.
  """
  return checkpoint_pb2.Checkpoint(
    model_data=pickle.dumps(self.GetModelData()),
  )
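# --- Illustrative sketch (added, not original code) -------------------------
# Persisting the message returned by SaveCheckpoint() under the
# "<epoch>.Checkpoint.pb" naming scheme expected by the selection helpers
# above. pbutil.ToFile() is assumed to exist as the write counterpart of the
# pbutil.FromFile() calls used throughout; verify against this repo's pbutil.
def WriteCheckpointSketch(
  model, log_dir: pathlib.Path, epoch_num: int
) -> pathlib.Path:
  checkpoint_path = log_dir / "checkpoints" / f"{epoch_num:03d}.Checkpoint.pb"
  checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
  pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)
  return checkpoint_path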
def TestOne(
  features_list_path: pathlib.Path,
  features_list_index: int,
  checkpoint_path: pathlib.Path,
) -> BatchResults:
  """Run a restored model on a single graph and annotate it with the results.

  Args:
    features_list_path: Path of a ProgramGraphFeaturesList proto file.
    features_list_index: Index of the graph features to use from that list.
    checkpoint_path: Path of the model checkpoint to restore.

  Returns:
    The input graph annotated with the batch results.
  """
  path = pathlib.Path(pathflag.path())
  features_list = pbutil.FromFile(
    features_list_path,
    program_graph_features_pb2.ProgramGraphFeaturesList(),
  )
  features = features_list.graph[features_list_index]

  # The graph name is the features file name without its extension.
  graph_name = features_list_path.name[: -len(".ProgramGraphFeaturesList.pb")]
  graph = pbutil.FromFile(
    path / "graphs" / f"{graph_name}.ProgramGraph.pb",
    program_graph_pb2.ProgramGraph(),
  )

  # Instantiate and restore the model.
  vocab = vocabulary.LoadVocabulary(
    path,
    model_name="cdfg" if FLAGS.cdfg else "programl",
    max_items=FLAGS.max_vocab_size,
    target_cumfreq=FLAGS.target_vocab_cumfreq,
  )
  if FLAGS.cdfg:
    FLAGS.use_position_embeddings = False
  model = Ggnn(
    vocabulary=vocab,
    test_only=True,
    node_y_dimensionality=2,
    graph_y_dimensionality=0,
    graph_x_dimensionality=0,
    use_selector_embeddings=True,
  )
  checkpoint = pbutil.FromFile(checkpoint_path, checkpoint_pb2.Checkpoint())
  model.RestoreCheckpoint(checkpoint)

  # Build a single-graph batch and run it in test mode.
  batch = list(
    DataflowGgnnBatchBuilder(
      graph_loader=SingleGraphLoader(graph=graph, features=features),
      vocabulary=vocab,
      max_node_size=int(1e9),
      use_cdfg=FLAGS.cdfg,
      max_batch_count=1,
    )
  )[0]
  results = model.RunBatch(epoch_pb2.TEST, batch)

  return AnnotateGraphWithBatchResults(graph, features, results)
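# --- Illustrative sketch (added, not original code) -------------------------
# Minimal driver for TestOne(): annotate the first graph of a features list
# with a restored model's predictions and print the result. The
# --features_list and --checkpoint flag names are hypothetical placeholders,
# not flags defined in this module.
def TestOneExampleSketch() -> None:
  annotated = TestOne(
    features_list_path=pathlib.Path(FLAGS.features_list),  # hypothetical flag
    features_list_index=0,
    checkpoint_path=pathlib.Path(FLAGS.checkpoint),  # hypothetical flag
  )
  print(annotated)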