# Example #1
# 0
def main():
    """Compare a local NNUE model's evaluations against an engine's.

    Positions are read from ``--data`` in batches and filtered with
    ``filter_fens``. Each batch is evaluated by the local model and by the
    engine at ``--engine``. The local model is loaded from ``--checkpoint``
    if given, otherwise from ``--net``. Both lists of evaluations are then
    passed to ``compute_correlation``.

    Processing stops once at least ``--count`` positions have been
    evaluated, or when the data provider runs out. Whole batches are
    processed, so the total may exceed ``--count``.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--net", type=str, help="path to a .nnue net")
    parser.add_argument("--engine", type=str, help="path to stockfish")
    parser.add_argument("--data",
                        type=str,
                        help="path to a .bin or .binpack dataset")
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Optional checkpoint (used instead of nnue for local eval)")
    parser.add_argument("--count",
                        type=int,
                        default=100,
                        help="number of datapoints to process")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    batch_size = 1000

    feature_set = features.get_feature_set_from_name(args.features)
    if args.checkpoint:
        model = NNUE.load_from_checkpoint(args.checkpoint,
                                          feature_set=feature_set)
    else:
        model = read_model(args.net, feature_set)
    model.eval()
    model.cuda()
    fen_batch_provider = make_fen_batch_provider(args.data, batch_size)

    model_evals = []
    engine_evals = []

    done = 0
    print('Processed {} positions.'.format(done))
    while done < args.count:
        try:
            raw_fens = next(fen_batch_provider)
        except StopIteration:
            # The dataset holds fewer positions than --count. Correlate what
            # we have instead of crashing with a bare StopIteration.
            print('Data exhausted after {} positions.'.format(done))
            break

        fens = filter_fens(raw_fens)
        if not fens:
            # Nothing survived filtering; don't build or evaluate an empty batch.
            continue

        # The trailing lists are constant per-position values; only the FENs
        # vary. (Presumably score/ply/result placeholders - TODO confirm.)
        b = nnue_dataset.make_sparse_batch_from_fens(feature_set, fens,
                                                     [0] * len(fens),
                                                     [1] * len(fens),
                                                     [0] * len(fens))
        try:
            model_evals += eval_model_batch(model, b)
        finally:
            # Release the batch even if model evaluation raises.
            nnue_dataset.destroy_sparse_batch(b)

        engine_evals += eval_engine_batch(args.engine, args.net, fens)

        done += len(fens)
        print('Processed {} positions.'.format(done))

    compute_correlation(engine_evals, model_evals)
# Example #2
# 0
 def commit_batch():
     nonlocal fens
     nonlocal results
     nonlocal scores
     nonlocal plies
     nonlocal model_evals
     nonlocal engine_evals
     if len(fens) == 0:
         return
     b = nnue_dataset.make_sparse_batch_from_fens(feature_set, fens, scores, plies, results)
     model_evals += eval_model_batch(model, b)
     nnue_dataset.destroy_sparse_batch(b)
     engine_evals += eval_engine_batch(args.engine, args.net, fens)
     fens = []
     results = []
     scores = []
     plies = []