示例#1
0
def SelectTrainingCheckpoint(
    log_dir: pathlib.Path,
) -> Tuple[epoch_pb2.Epoch, checkpoint_pb2.Checkpoint]:
    """Select a checkpoint to load to resume training.

    Every "*.EpochList.pbtxt" file in <log_dir>/epochs is inspected. The
    epoch with the highest epoch number whose training results record a
    non-zero graph count is selected, and the matching checkpoint is loaded
    from <log_dir>/checkpoints.

    Args:
      log_dir: The directory containing the "epochs" and "checkpoints"
        subdirectories.

    Returns:
      A tuple of <Epoch, Checkpoint> messages.

    Raises:
      FileNotFoundError: If no epoch with training results is found.
    """
    epochs_dir = log_dir / "epochs"

    trained_epoch_nums = []
    for path in epochs_dir.iterdir():
        if not path.name.endswith(".EpochList.pbtxt"):
            continue
        candidate = pbutil.FromFile(path, epoch_pb2.EpochList()).epoch[0]
        # Epochs whose training results have no graphs are not resumable.
        if candidate.train_results.graph_count:
            trained_epoch_nums.append(candidate.epoch_num)

    # Fail with a clear message instead of formatting a sentinel epoch number
    # into a bogus file name such as "-01.EpochList.pbtxt".
    if not trained_epoch_nums:
        raise FileNotFoundError(
            f"No epoch with training results found in '{epochs_dir}'"
        )
    epoch_num = max(trained_epoch_nums)

    epoch_list = pbutil.FromFile(
        epochs_dir / f"{epoch_num:03d}.EpochList.pbtxt",
        epoch_pb2.EpochList(),
    )
    checkpoint = pbutil.FromFile(
        log_dir / "checkpoints" / f"{epoch_num:03d}.Checkpoint.pb",
        checkpoint_pb2.Checkpoint(),
    )
    app.Log(
        1,
        "Resuming training from checkpoint %d with val F1 score %.3f",
        epoch_list.epoch[0].epoch_num,
        epoch_list.epoch[0].val_results.mean_f1,
    )
    return epoch_list.epoch[0], checkpoint
示例#2
0
def SelectTestCheckpoint(
    log_dir: Path,
) -> Tuple[epoch_pb2.Epoch, checkpoint_pb2.Checkpoint]:
    """Select a checkpoint to load for testing.

    The training checkpoint with the highest validation F1 score is used for
    testing. When several epochs share the best score, the one with the
    highest epoch number is chosen so that the selection does not depend on
    directory listing order.

    Args:
      log_dir: The directory containing the "epochs" and "checkpoints"
        subdirectories.

    Returns:
      A tuple of <Epoch, Checkpoint> messages.

    Raises:
      FileNotFoundError: If no epoch with a usable validation F1 score is
        found.
    """
    epochs_dir = log_dir / "epochs"

    best_f1 = -1
    best_epoch_num = None
    for path in epochs_dir.iterdir():
        if not path.name.endswith(".EpochList.pbtxt"):
            continue
        candidate = pbutil.FromFile(path, epoch_pb2.EpochList()).epoch[0]
        f1 = candidate.val_results.mean_f1
        epoch_num = candidate.epoch_num
        # Directory iteration order is arbitrary, so ties are broken
        # explicitly in favour of the later epoch.
        wins_tie = f1 == best_f1 and (
            best_epoch_num is None or epoch_num > best_epoch_num
        )
        if f1 > best_f1 or wins_tie:
            best_f1 = f1
            best_epoch_num = epoch_num

    # Fail with a clear message instead of a TypeError from formatting None.
    if best_epoch_num is None:
        raise FileNotFoundError(
            f"No epoch with a usable validation F1 score found in "
            f"'{epochs_dir}'"
        )

    epoch_list = pbutil.FromFile(
        epochs_dir / f"{best_epoch_num:03d}.EpochList.pbtxt",
        epoch_pb2.EpochList(),
    )
    checkpoint = pbutil.FromFile(
        log_dir / "checkpoints" / f"{best_epoch_num:03d}.Checkpoint.pb",
        checkpoint_pb2.Checkpoint(),
    )
    logging.info(
        "Selected best checkpoint %d with val F1 score %.3f",
        epoch_list.epoch[0].epoch_num,
        epoch_list.epoch[0].val_results.mean_f1,
    )
    return epoch_list.epoch[0], checkpoint
示例#3
0
    def SaveCheckpoint(self) -> checkpoint_pb2.Checkpoint:
        """Snapshot the current model state as a checkpoint message.

        Returns:
          A Checkpoint message whose model_data field holds the pickled
          result of GetModelData().
        """
        serialized_model = pickle.dumps(self.GetModelData())
        return checkpoint_pb2.Checkpoint(model_data=serialized_model)
示例#4
0
def TestOne(
    features_list_path: pathlib.Path,
    features_list_index: int,
    checkpoint_path: pathlib.Path,
) -> BatchResults:
    """Run a restored model in test mode on a single graph.

    Args:
      features_list_path: Path of a ProgramGraphFeaturesList file. Its file
        name, minus the ".ProgramGraphFeaturesList.pb" suffix, names the
        graph file to load from the "graphs" directory.
      features_list_index: Index of the features entry to use within the
        features list.
      checkpoint_path: Path of the Checkpoint file to restore the model from.

    Returns:
      The result of AnnotateGraphWithBatchResults() for the graph, the
      selected features and the batch results.
    """
    dataset_root = pathlib.Path(pathflag.path())

    # Pick the requested features entry out of the features list file.
    all_features = pbutil.FromFile(
        features_list_path,
        program_graph_features_pb2.ProgramGraphFeaturesList(),
    )
    selected_features = all_features.graph[features_list_index]

    # The graph file shares its stem with the features list file.
    suffix_length = len(".ProgramGraphFeaturesList.pb")
    stem = features_list_path.name[:-suffix_length]
    program_graph = pbutil.FromFile(
        dataset_root / "graphs" / f"{stem}.ProgramGraph.pb",
        program_graph_pb2.ProgramGraph(),
    )

    # Build the vocabulary matching the selected graph representation.
    model_name = "cdfg" if FLAGS.cdfg else "programl"
    vocab = vocabulary.LoadVocabulary(
        dataset_root,
        model_name=model_name,
        max_items=FLAGS.max_vocab_size,
        target_cumfreq=FLAGS.target_vocab_cumfreq,
    )

    # Position embeddings are switched off for CDFG runs. This mutates the
    # global flag before the model is constructed.
    if FLAGS.cdfg:
        FLAGS.use_position_embeddings = False

    # Instantiate the model and restore its state from the checkpoint.
    ggnn = Ggnn(
        vocabulary=vocab,
        test_only=True,
        node_y_dimensionality=2,
        graph_y_dimensionality=0,
        graph_x_dimensionality=0,
        use_selector_embeddings=True,
    )
    restored = pbutil.FromFile(checkpoint_path, checkpoint_pb2.Checkpoint())
    ggnn.RestoreCheckpoint(restored)

    # Build batches from the single graph and keep the first one.
    loader = SingleGraphLoader(graph=program_graph, features=selected_features)
    builder = DataflowGgnnBatchBuilder(
        graph_loader=loader,
        vocabulary=vocab,
        max_node_size=int(1e9),
        use_cdfg=FLAGS.cdfg,
        max_batch_count=1,
    )
    batches = list(builder)
    first_batch = batches[0]

    batch_results = ggnn.RunBatch(epoch_pb2.TEST, first_batch)

    return AnnotateGraphWithBatchResults(
        program_graph, selected_features, batch_results
    )