Example #1
def main():
    parser = argparse.ArgumentParser(
        description="Converts files between ckpt and nnue format.")
    parser.add_argument("source",
                        help="Source file (can be .ckpt, .pt or .nnue)")
    parser.add_argument("target", help="Target file (can be .pt or .nnue)")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    feature_set = features.get_feature_set_from_name(args.features)

    print('Converting %s to %s' % (args.source, args.target))

    if args.source.endswith(".pt") or args.source.endswith(".ckpt"):
        if not args.target.endswith(".nnue"):
            raise Exception("Target file must end with .nnue")
        if args.source.endswith(".pt"):
            nnue = torch.load(args.source)
        else:
            nnue = M.NNUE.load_from_checkpoint(args.source,
                                               feature_set=feature_set)
        nnue.eval()
        writer = NNUEWriter(nnue)
        with open(args.target, 'wb') as f:
            f.write(writer.buf)
    elif args.source.endswith(".nnue"):
        if not args.target.endswith(".pt"):
            raise Exception("Target file must end with .pt")
        with open(args.source, 'rb') as f:
            reader = NNUEReader(f, feature_set)
        torch.save(reader.model, args.target)
    else:
        raise Exception('Invalid filetypes: ' + str(args))
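
These examples are bare main() entry points with their module headers omitted. A minimal import sketch for this converter, assuming the usual nnue-pytorch layout (features and model are repo modules; NNUEReader and NNUEWriter are assumed to live in a serialize module):

import argparse
import torch
import features
import model as M
from serialize import NNUEReader, NNUEWriter  # assumed module path

A hypothetical invocation (script and file names are placeholders):

python serialize.py last.ckpt nn.nnue --features=HalfKAv2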
Example #2
def main():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--net", type=str, help="path to a .nnue net")
    parser.add_argument("--engine", type=str, help="path to stockfish")
    parser.add_argument("--data",
                        type=str,
                        help="path to a .bin or .binpack dataset")
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Optional checkpoint (used instead of nnue for local eval)")
    parser.add_argument("--count",
                        type=int,
                        default=100,
                        help="number of datapoints to process")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    batch_size = 1000

    feature_set = features.get_feature_set_from_name(args.features)
    if args.checkpoint:
        model = NNUE.load_from_checkpoint(args.checkpoint,
                                          feature_set=feature_set)
    else:
        model = read_model(args.net, feature_set)
    model.eval()
    model.cuda()
    fen_batch_provider = make_fen_batch_provider(args.data, batch_size)

    model_evals = []
    engine_evals = []

    done = 0
    print('Processed {} positions.'.format(done))
    while done < args.count:
        fens = filter_fens(next(fen_batch_provider))

        b = nnue_dataset.make_sparse_batch_from_fens(feature_set, fens,
                                                     [0] * len(fens),
                                                     [1] * len(fens),
                                                     [0] * len(fens))
        model_evals += eval_model_batch(model, b)
        nnue_dataset.destroy_sparse_batch(b)

        engine_evals += eval_engine_batch(args.engine, args.net, fens)

        done += len(fens)
        print('Processed {} positions.'.format(done))

    compute_correlation(engine_evals, model_evals)
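
eval_model_batch is not shown above; a heavily hedged sketch of what it is assumed to do, reusing the (us, them, white, black) forward signature seen in Example #10 (get_tensors and the exact tensor tuple are assumptions, not confirmed API):

def eval_model_batch(model, batch):
    # Hypothetical: unpack the sparse batch into the model's forward inputs
    # and return one scalar eval per position.
    us, them, white, black, outcome, score = batch.contents.get_tensors('cuda')
    with torch.no_grad():
        return [v.item() for v in model(us, them, white, black)]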
Example #3
def main():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--net", type=str, help="path to a .nnue net")
    parser.add_argument("--engine", type=str, help="path to stockfish")
    parser.add_argument("--data", type=str, help="path to .bin dataset")
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Optional checkpoint (used instead of nnue for local eval)")
    parser.add_argument("--count",
                        type=int,
                        default=100,
                        help="number of datapoints to process")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    feature_set = features.get_feature_set_from_name(args.features)
    if args.checkpoint:
        model = NNUE.load_from_checkpoint(args.checkpoint,
                                          feature_set=feature_set)
    else:
        model = read_model(args.net, feature_set)
    model.eval()
    data_reader = make_data_reader(args.data, feature_set)

    fens = []
    model_evals = []
    i = -1
    done = 0
    while done < args.count:
        i += 1

        item = data_reader.get_raw(i)
        board = item[0]
        if board.is_check():
            continue

        fen = board.fen()
        fens.append(fen)
        model_eval = eval_model(model, data_reader.transform(item))
        model_evals.append(model_eval)

        done += 1

    engine_evals = eval_engine_batch(args.engine, args.net, fens)
    compute_correlation(engine_evals, model_evals)
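
A hypothetical invocation matching the flags defined above (script and file names are placeholders):

python cross_check_eval.py --net=nn.nnue --engine=./stockfish --data=val.bin --count=1000 --features=HalfKAv2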
Example #4
def main():
    parser = argparse.ArgumentParser(
        description="Converts files between ckpt and nnue format.")
    parser.add_argument("source",
                        help="Source file (can be .ckpt, .pt or .nnue)")
    parser.add_argument("target", help="Target file (can be .pt or .nnue)")
    parser.add_argument(
        "--description",
        default=None,
        type=str,
        dest='description',
        help=
        "The description string to include in the network. Only works when serializing into a .nnue file."
    )
    features.add_argparse_args(parser)
    args = parser.parse_args()

    feature_set = features.get_feature_set_from_name(args.features)

    print('Converting %s to %s' % (args.source, args.target))

    if args.source.endswith('.ckpt'):
        nnue = M.NNUE.load_from_checkpoint(args.source,
                                           feature_set=feature_set)
        nnue.eval()
    elif args.source.endswith('.pt'):
        nnue = torch.load(args.source)
    elif args.source.endswith('.nnue'):
        with open(args.source, 'rb') as f:
            reader = NNUEReader(f, feature_set)
            nnue = reader.model
    else:
        raise Exception('Invalid network input format.')

    if args.target.endswith('.ckpt'):
        raise Exception('Cannot convert into .ckpt')
    elif args.target.endswith('.pt'):
        torch.save(nnue, args.target)
    elif args.target.endswith('.nnue'):
        writer = NNUEWriter(nnue, args.description)
        with open(args.target, 'wb') as f:
            f.write(writer.buf)
    else:
        raise Exception('Invalid network output format.')
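
This converter variant also accepts --description, which is embedded when serializing into a .nnue file. A hypothetical invocation (names are placeholders):

python serialize.py last.ckpt nn.nnue --features=HalfKAv2 --description="example description"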
Example #5
def main():
    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("train", help="Training data (.bin or .binpack)")
    parser.add_argument("val", help="Validation data (.bin or .binpack)")
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--lambda",
        default=1.0,
        type=float,
        dest='lambda_',
        help=
        "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)."
    )
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        dest='num_workers',
        help=
        "Number of worker threads to use for data loading. Currently only works well for binpack."
    )
    parser.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        dest='batch_size',
        help=
        "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128."
    )
    parser.add_argument(
        "--threads",
        default=-1,
        type=int,
        dest='threads',
        help="Number of torch threads to use. Default automatic (cores) .")
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        dest='seed',
                        help="torch seed to use.")
    parser.add_argument(
        "--smart-fen-skipping",
        action='store_true',
        dest='smart_fen_skipping_deprecated',
        help=
        "If enabled positions that are bad training targets will be skipped during loading. Default: True, kept for backwards compatibility. This option is ignored"
    )
    parser.add_argument(
        "--no-smart-fen-skipping",
        action='store_true',
        dest='no_smart_fen_skipping',
        help=
        "If used then no smart fen skipping will be done. By default smart fen skipping is done."
    )
    parser.add_argument(
        "--random-fen-skipping",
        default=3,
        type=int,
        dest='random_fen_skipping',
        help=
        "skip fens randomly on average random_fen_skipping before using one.")
    parser.add_argument(
        "--resume-from-model",
        dest='resume_from_model',
        help="Initializes training using the weights from the given .pt model")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    if not os.path.exists(args.train):
        raise Exception('{0} does not exist'.format(args.train))
    if not os.path.exists(args.val):
        raise Exception('{0} does not exist'.format(args.val))

    feature_set = features.get_feature_set_from_name(args.features)

    if args.resume_from_model is None:
        nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_)
        nnue.cuda()
    else:
        nnue = torch.load(args.resume_from_model)
        nnue.set_feature_set(feature_set)
        nnue.lambda_ = args.lambda_
        nnue.cuda()

    print("Feature set: {}".format(feature_set.name))
    print("Num real features: {}".format(feature_set.num_real_features))
    print("Num virtual features: {}".format(feature_set.num_virtual_features))
    print("Num features: {}".format(feature_set.num_features))

    print("Training with {} validating with {}".format(args.train, args.val))

    pl.seed_everything(args.seed)
    print("Seed {}".format(args.seed))

    batch_size = args.batch_size
    if batch_size <= 0:
        batch_size = 16384
    print('Using batch size {}'.format(batch_size))

    print('Smart fen skipping: {}'.format(not args.no_smart_fen_skipping))
    print('Random fen skipping: {}'.format(args.random_fen_skipping))

    if args.threads > 0:
        print('limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    logdir = args.default_root_dir if args.default_root_dir else 'logs/'
    print('Using log dir {}'.format(logdir), flush=True)

    tb_logger = pl_loggers.TensorBoardLogger(logdir)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=True,
                                                       period=20,
                                                       save_top_k=-1)
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[checkpoint_callback],
                                            logger=tb_logger)

    main_device = (trainer.root_device if trainer.root_gpu is None else
                   'cuda:' + str(trainer.root_gpu))

    print('Using c++ data loader')
    train, val = make_data_loaders(args.train, args.val, feature_set,
                                   args.num_workers, batch_size,
                                   not args.no_smart_fen_skipping,
                                   args.random_fen_skipping, main_device)

    trainer.fit(nnue, train, val)
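
A hedged import header for this trainer, assuming the nnue-pytorch layout (make_data_loaders is the repo's data-loading helper; the rest are standard imports):

import argparse
import os
import torch
from torch import set_num_threads as t_set_num_threads
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
import features
import model as M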
Example #6
def main():
    parser = argparse.ArgumentParser(
        description="Visualizes networks in ckpt, pt and nnue format.")
    parser.add_argument("model",
                        help="Source model (can be .ckpt, .pt or .nnue)")
    parser.add_argument(
        "--ref-model",
        type=str,
        required=False,
        help=
        "Visualize the difference between the given reference model (can be .ckpt, .pt or .nnue)."
    )
    parser.add_argument(
        "--ref-features",
        type=str,
        required=False,
        help=
        "The reference feature set to use (default = same as source model).")
    parser.add_argument(
        "--input-weights-vmin",
        default=-1,
        type=float,
        help=
        "Minimum of color map range for input weights (absolute values are plotted if this is positive or zero)."
    )
    parser.add_argument("--input-weights-vmax",
                        default=1,
                        type=float,
                        help="Maximum of color map range for input weights.")
    parser.add_argument(
        "--input-weights-auto-scale",
        action="store_true",
        help=
        "Use auto-scale for the color map range for input weights. This ignores input-weights-vmin and input-weights-vmax."
    )
    parser.add_argument(
        "--input-weights-order",
        type=str,
        choices=["piece-centric-flipped-king", "king-centric"],
        default="piece-centric-flipped-king",
        help="Order of the input weights for each input neuron.")
    parser.add_argument(
        "--sort-input-neurons",
        action="store_true",
        help=
        "Sort the neurons of the input layer by the L1-norm (sum of absolute values) of their weights."
    )
    parser.add_argument(
        "--fc-weights-vmin",
        default=-2,
        type=float,
        help=
        "Minimum of color map range for fully-connected layer weights (absolute values are plotted if this is positive or zero)."
    )
    parser.add_argument(
        "--fc-weights-vmax",
        default=2,
        type=float,
        help="Maximum of color map range for fully-connected layer weights.")
    parser.add_argument(
        "--fc-weights-auto-scale",
        action="store_true",
        help=
        "Use auto-scale for the color map range for fully-connected layer weights. This ignores fc-weights-vmin and fc-weights-vmax."
    )
    parser.add_argument("--no-hist",
                        action="store_true",
                        help="Don't generate any histograms.")
    parser.add_argument("--no-biases",
                        action="store_true",
                        help="Don't generate plots for biases.")
    parser.add_argument(
        "--no-input-weights",
        action="store_true",
        help="Don't generate plots or histograms for input weights.")
    parser.add_argument(
        "--no-fc-weights",
        action="store_true",
        help=
        "Don't generate plots or histograms for fully-connected layer weights."
    )
    parser.add_argument("--default-width",
                        default=1600,
                        type=int,
                        help="Default width of all plots (in pixels).")
    parser.add_argument("--default-height",
                        default=900,
                        type=int,
                        help="Default height of all plots (in pixels).")
    parser.add_argument("--save-dir",
                        type=str,
                        required=False,
                        help="Save the plots in this directory.")
    parser.add_argument("--dont-show",
                        action="store_true",
                        help="Don't show the plots.")
    parser.add_argument(
        "--label",
        type=str,
        required=False,
        help=
        "Override the label used in plot titles and as prefix of saved files.")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    supported_features = ('HalfKAv2', 'HalfKAv2^')
    assert args.features in supported_features
    feature_set = features.get_feature_set_from_name(args.features)

    from os.path import basename
    label = basename(args.model)

    model = load_model(args.model, feature_set)

    if args.ref_model:
        if args.ref_features:
            assert args.ref_features in supported_features
            ref_feature_set = features.get_feature_set_from_name(
                args.ref_features)
        else:
            ref_feature_set = feature_set

        ref_model = load_model(args.ref_model, ref_feature_set)

        print("Visualizing difference between {} and {}".format(
            args.model, args.ref_model))

        label = "diff " + label + "-" + basename(args.ref_model)
    else:
        ref_model = None
        print("Visualizing {}".format(args.model))

    if args.label is None:
        args.label = label

    visualizer = NNUEVisualizer(model, ref_model, args)

    visualizer.plot_input_weights()
    #visualizer.plot_fc_weights()
    #visualizer.plot_biases()

    if not args.dont_show:
        plt.show()
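
A hypothetical invocation (the script name is a placeholder):

python visualize.py nn.nnue --features=HalfKAv2 --ref-model=nn-old.nnue --save-dir=plots --dont-show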
Example #7
def main():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--net", type=str, help="path to a .nnue net")
    parser.add_argument("--engine", type=str, help="path to stockfish")
    parser.add_argument("--data", type=str, help="path to .bin dataset")
    parser.add_argument("--checkpoint", type=str, help="Optional checkpoint (used instead of nnue for local eval)")
    parser.add_argument("--count", type=int, default=100, help="number of datapoints to process")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    feature_set = features.get_feature_set_from_name(args.features)
    if args.checkpoint:
        model = NNUE.load_from_checkpoint(args.checkpoint,
                                          feature_set=feature_set)
    else:
        model = read_model(args.net, feature_set)
    model.eval()
    data_reader = make_data_reader(args.data, feature_set)

    fens = []
    results = []
    scores = []
    plies = []
    model_evals = []
    engine_evals = []
    i = -1

    def commit_batch():
        nonlocal fens, results, scores, plies, model_evals, engine_evals
        if len(fens) == 0:
            return
        b = nnue_dataset.make_sparse_batch_from_fens(feature_set, fens, scores, plies, results)
        model_evals += eval_model_batch(model, b)
        nnue_dataset.destroy_sparse_batch(b)
        engine_evals += eval_engine_batch(args.engine, args.net, fens)
        fens = []
        results = []
        scores = []
        plies = []

    done = 0
    while done < args.count:
        i += 1

        item = data_reader.get_raw(i)
        board = item[0]
        if board.is_check():
            continue

        fens.append(board.fen())
        results.append(int(round(item[2] * 2 - 1)))
        scores.append(int(item[3]))
        plies.append(1)

        done += 1

        if done % 1024 == 0:
            # don't do batches that are too big
            commit_batch()

    commit_batch()

    compute_correlation(engine_evals, model_evals)
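
compute_correlation is shared by several of these scripts but never shown; a minimal sketch of what it is assumed to do (Pearson correlation via numpy):

def compute_correlation(engine_evals, model_evals):
    # Hypothetical helper: report how closely the model's evals track the engine's.
    import numpy as np
    a = np.asarray(engine_evals, dtype=float)
    b = np.asarray(model_evals, dtype=float)
    print('correlation: {}'.format(np.corrcoef(a, b)[0, 1]))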
Example #8
def main():
    parser = argparse.ArgumentParser(
        description="Visualizes networks in ckpt, pt and nnue format.")
    parser.add_argument("models",
                        nargs='+',
                        help="Source model (can be .ckpt, .pt or .nnue)")
    parser.add_argument("--dont-show",
                        action="store_true",
                        help="Don't show the plots.")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    supported_features = ('HalfKAv2', 'HalfKAv2^')
    assert args.features in supported_features
    feature_set = features.get_feature_set_from_name(args.features)

    from os.path import basename
    labels = []
    for m in args.models:
        label = basename(m)
        if label.startswith('nn-'):
            label = label[3:]
        if label.endswith('.nnue'):
            label = label[:-5]
        labels.append('\n'.join(label.split('-')))

    models = [load_model(m, feature_set) for m in args.models]

    coalesced_ins = [
        M.coalesce_ft_weights(model, model.input) for model in models
    ]
    input_weights = [
        coalesced_in[:, :M.L1].flatten().numpy()
        for coalesced_in in coalesced_ins
    ]
    input_weights_psqt = [(coalesced_in[:, M.L1:] * 600).flatten().numpy()
                          for coalesced_in in coalesced_ins]
    plot_hists(
        [input_weights],
        labels, [None],
        w=10.0,
        h=3.0,
        num_bins=8 * 128,
        title=
        'Distribution of feature transformer weights among different nets',
        filename='input_weights_hist.png')
    plot_hists(
        [input_weights_psqt],
        labels, [None],
        w=10.0,
        h=3.0,
        num_bins=8 * 128,
        title=
        'Distribution of feature transformer PSQT weights among different nets (in stockfish internal units)',
        filename='input_weights_psqt_hist.png')

    layer_stacks = [model.layer_stacks for model in models]
    layers_l1 = [[] for i in range(layer_stacks[0].count)]
    layers_l2 = [[] for i in range(layer_stacks[0].count)]
    layers_l3 = [[] for i in range(layer_stacks[0].count)]
    for ls in layer_stacks:
        for i, sublayers in enumerate(ls.get_coalesced_layer_stacks()):
            l1, l2, l3 = sublayers
            layers_l1[i].append(l1.weight.flatten().numpy())
            layers_l2[i].append(l2.weight.flatten().numpy())
            layers_l3[i].append(l3.weight.flatten().numpy())
    col_names = ['Subnet {}'.format(i) for i in range(layer_stacks[0].count)]
    plot_hists(
        layers_l1,
        labels,
        col_names,
        w=2.0,
        h=2.0,
        num_bins=128,
        title='Distribution of l1 weights among different nets and buckets',
        filename='l1_weights_hist.png')
    plot_hists(
        layers_l2,
        labels,
        col_names,
        w=2.0,
        h=2.0,
        num_bins=32,
        title='Distribution of l2 weights among different nets and buckets',
        filename='l2_weights_hist.png')
    plot_hists(
        layers_l3,
        labels,
        col_names,
        w=2.0,
        h=2.0,
        num_bins=16,
        title='Distribution of output weights among different nets and buckets',
        filename='output_weights_hist.png')

    if not args.dont_show:
        plt.show()
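
plot_hists is not shown; a sketch consistent with the call sites above (the grid layout and argument meanings are inferred, not confirmed):

def plot_hists(datas, row_labels, col_names, w=2.0, h=2.0, num_bins=128,
               title=None, filename=None):
    # Hypothetical sketch: datas[c][r] is a flat numpy array of weights for
    # column c (bucket) and row r (net); one histogram per cell.
    import matplotlib.pyplot as plt
    rows, cols = len(row_labels), len(datas)
    fig, axes = plt.subplots(rows, cols, figsize=(w * cols, h * rows),
                             squeeze=False)
    for c in range(cols):
        for r in range(rows):
            ax = axes[r][c]
            ax.hist(datas[c][r], bins=num_bins)
            if r == 0 and col_names[c]:
                ax.set_title(col_names[c])
            if c == 0:
                ax.set_ylabel(row_labels[r])
    if title:
        fig.suptitle(title)
    if filename:
        fig.savefig(filename)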
Example #9
def main():
    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("train", help="Training data (.bin or .binpack)")
    parser.add_argument("val", help="Validation data (.bin or .binpack)")
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument("--py-data",
                        action="store_true",
                        help="Use python data loader (default=False)")
    parser.add_argument(
        "--lambda",
        default=1.0,
        type=float,
        dest='lambda_',
        help=
        "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)."
    )
    parser.add_argument("--alpha",
                        default=1.0,
                        type=float,
                        dest='alpha_',
                        help="random multiply factor (default=1.0).")
    parser.add_argument(
        "--beta",
        default=6000,
        type=int,
        dest='beta_',
        help=
        "definite random step frequency - according to steps (default=6000).")
    parser.add_argument(
        "--gamma",
        default=0.0005,
        type=float,
        dest='gamma_',
        help="randomized random step frequency (default=0.0005).")
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        dest='num_workers',
        help=
        "Number of worker threads to use for data loading. Currently only works well for binpack."
    )
    parser.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        dest='batch_size',
        help=
        "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128."
    )
    parser.add_argument(
        "--threads",
        default=-1,
        type=int,
        dest='threads',
        help="Number of torch threads to use. Default automatic (cores) .")
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        dest='seed',
                        help="torch seed to use.")
    parser.add_argument(
        "--smart-fen-skipping",
        action='store_true',
        dest='smart_fen_skipping',
        help=
        "If enabled positions that are bad training targets will be skipped during loading. Default: False"
    )
    parser.add_argument(
        "--random-fen-skipping",
        default=0,
        type=int,
        dest='random_fen_skipping',
        help=
        "skip fens randomly on average random_fen_skipping before using one.")
    parser.add_argument(
        "--resume-from-model",
        dest='resume_from_model',
        help="Initializes training using the weights from the given .pt model")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    if not os.path.exists(args.train):
        raise Exception('{0} does not exist'.format(args.train))
    if not os.path.exists(args.val):
        raise Exception('{0} does not exist'.format(args.val))

    feature_set = features.get_feature_set_from_name(args.features)

    if args.resume_from_model is None:
        nnue = M.NNUE(feature_set=feature_set,
                      lambda_=args.lambda_,
                      alpha_=args.alpha_,
                      beta_=args.beta_,
                      gamma_=args.gamma_)
    else:
        nnue = torch.load(args.resume_from_model)
        nnue.set_feature_set(feature_set)
        nnue.lambda_ = args.lambda_
        nnue.alpha_ = args.alpha_
        nnue.beta_ = args.beta_
        nnue.gamma_ = args.gamma_

    print("Feature set: {}".format(feature_set.name))
    print("Num real features: {}".format(feature_set.num_real_features))
    print("Num virtual features: {}".format(feature_set.num_virtual_features))
    print("Num features: {}".format(feature_set.num_features))

    print("Training with {} validating with {}".format(args.train, args.val))

    pl.seed_everything(args.seed)
    print("Seed {}".format(args.seed))

    batch_size = args.batch_size
    if batch_size <= 0:
        batch_size = 128 if args.gpus == 0 else 8192
    print('Using batch size {}'.format(batch_size))

    print('Smart fen skipping: {}'.format(args.smart_fen_skipping))
    print('Random fen skipping: {}'.format(args.random_fen_skipping))

    if args.threads > 0:
        print('limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    logdir = args.default_root_dir if args.default_root_dir else 'logs/'
    print('Using log dir {}'.format(logdir), flush=True)

    wandb_logger = WandbLogger()
    checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=True,
                                                       period=5,
                                                       save_top_k=-1)
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[checkpoint_callback],
                                            logger=wandb_logger)

    main_device = (trainer.root_device if trainer.root_gpu is None else
                   'cuda:' + str(trainer.root_gpu))

    if args.py_data:
        print('Using python data loader')
        train, val = data_loader_py(args.train, args.val, feature_set,
                                    batch_size, main_device)
    else:
        print('Using c++ data loader')
        train, val = data_loader_cc(args.train, args.val, feature_set,
                                    args.num_workers, batch_size,
                                    args.smart_fen_skipping,
                                    args.random_fen_skipping, main_device)

    trainer.fit(nnue, train, val)
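
A hypothetical training invocation (script name and paths are placeholders; --gpus and --max_epochs come from pl.Trainer.add_argparse_args):

python train.py train.binpack val.binpack --gpus=1 --max_epochs=400 --num-workers=2 --batch-size=8192 --smart-fen-skipping --random-fen-skipping=3 --features=HalfKAv2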
Example #10
def main():
    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("train", help="Training data (.bin or .binpack)")
    parser.add_argument("val", help="Validation data (.bin or .binpack)")

    parser.add_argument("--tune",
                        action="store_true",
                        help="automated LR search")
    parser.add_argument(
        "--save",
        action="store_true",
        help="save after every training epoch (default = False)")
    parser.add_argument("--experiment",
                        default="1",
                        type=str,
                        help="specify the experiment id")
    parser.add_argument("--py-data",
                        action="store_true",
                        help="Use python data loader (default=False)")
    parser.add_argument(
        "--lambda",
        default=1.0,
        type=float,
        dest='lambda_',
        help=
        "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)."
    )
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        dest='num_workers',
        help=
        "Number of worker threads to use for data loading. Currently only works well for binpack."
    )
    parser.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        dest='batch_size',
        help=
        "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128."
    )
    parser.add_argument(
        "--threads",
        default=-1,
        type=int,
        dest='threads',
        help="Number of torch threads to use. Default automatic (cores) .")
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        dest='seed',
                        help="torch seed to use.")
    parser.add_argument(
        "--smart-fen-skipping",
        action='store_true',
        dest='smart_fen_skipping',
        help=
        "If enabled positions that are bad training targets will be skipped during loading. Default: False"
    )
    parser.add_argument(
        "--random-fen-skipping",
        default=0,
        type=int,
        dest='random_fen_skipping',
        help=
        "skip fens randomly on average random_fen_skipping before using one.")
    parser.add_argument(
        "--resume-from-model",
        dest='resume_from_model',
        help="Initializes training using the weights from the given .pt model")

    features.add_argparse_args(parser)
    args = parser.parse_args()

    print("Training with {} validating with {}".format(args.train, args.val))

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    batch_size = args.batch_size
    if batch_size <= 0:
        # no pl.Trainer argparse args here, so args.gpus does not exist; fall
        # back on CUDA availability to pick the default
        batch_size = 128 if not torch.cuda.is_available() else 8192
    print('Using batch size {}'.format(batch_size))

    print('Smart fen skipping: {}'.format(args.smart_fen_skipping))
    print('Random fen skipping: {}'.format(args.random_fen_skipping))

    if args.threads > 0:
        print('limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    feature_set = features.get_feature_set_from_name(args.features)

    if args.py_data:
        print('Using python data loader')
        train_data, val_data = data_loader_py(args.train, args.val, batch_size,
                                              feature_set, 'cuda:0')

    else:
        print('Using c++ data loader')
        train_data, val_data = data_loader_cc(
            args.train, args.val, feature_set, args.num_workers, batch_size,
            args.smart_fen_skipping, args.random_fen_skipping, 'cuda:0')

    print("Feature set: {}".format(feature_set.name))
    print("Num real features: {}".format(feature_set.num_real_features))
    print("Num virtual features: {}".format(feature_set.num_virtual_features))
    print("Num features: {}".format(feature_set.num_features))

    START_EPOCH = 0
    NUM_EPOCHS = 150
    SWA_START = int(0.75 * NUM_EPOCHS)

    LEARNING_RATE = 5e-4
    DECAY = 0
    EPS = 1e-7

    best_loss = 1000
    is_best = False

    early_stopping_delay = 30
    early_stopping_count = 0
    early_stopping_flag = False

    summary_location = 'logs/nnue_experiment_' + args.experiment
    save_location = '/home/esigelec/PycharmProjects/nnue-pytorch/save_models/' + args.experiment

    writer = SummaryWriter(summary_location)

    nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_, s=1)

    train_params = [{
        'params': nnue.get_1xlr(),
        'lr': LEARNING_RATE
    }, {
        'params': nnue.get_10xlr(),
        'lr': LEARNING_RATE * 10.0
    }]

    optimizer = ranger.Ranger(train_params,
                              lr=LEARNING_RATE,
                              eps=EPS,
                              betas=(0.9, 0.999),
                              weight_decay=DECAY)

    if args.resume_from_model is not None:
        nnue, optimizer, START_EPOCH = load_ckp(args.resume_from_model, nnue,
                                                optimizer)
        nnue.set_feature_set(feature_set)
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.1,
                                                           patience=7,
                                                           cooldown=1,
                                                           min_lr=1e-7,
                                                           verbose=True)
    swa_scheduler = SWALR(optimizer, annealing_epochs=5, swa_lr=[5e-5, 1e-4])

    nnue = nnue.cuda()
    swa_nnue = AveragedModel(nnue)

    for epoch in range(START_EPOCH, NUM_EPOCHS):

        nnue.train()

        train_interval = 100
        loss_f_sum_interval = 0.0
        loss_f_sum_epoch = 0.0
        loss_v_sum_epoch = 0.0

        if early_stopping_flag:
            print("early end of training at epoch" + str(epoch))
            break

        for batch_idx, batch in enumerate(train_data):

            batch = [_data.cuda() for _data in batch]
            us, them, white, black, outcome, score = batch

            optimizer.zero_grad()
            output = nnue(us, them, white, black)

            loss = nnue_loss(output, outcome, score, args.lambda_)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(nnue.parameters(), 0.5)
            optimizer.step()

            loss_f_sum_interval += loss.float()
            loss_f_sum_epoch += loss.float()

            if batch_idx % train_interval == train_interval - 1:

                writer.add_scalar('train_loss',
                                  loss_f_sum_interval / train_interval,
                                  epoch * len(train_data) + batch_idx)

                loss_f_sum_interval = 0.0

        print("Epoch #{}\t Train_Loss: {:.8f}\t".format(
            epoch, loss_f_sum_epoch / len(train_data)))

        if epoch % 1 == 0 or (epoch + 1) == NUM_EPOCHS:  # epoch % 1 == 0 is always true: validates every epoch

            with torch.no_grad():
                nnue.eval()
                for batch_idx, batch in enumerate(val_data):
                    batch = [_data.cuda() for _data in batch]
                    us, them, white, black, outcome, score = batch

                    _output = nnue(us, them, white, black)
                    loss_v = nnue_loss(_output, outcome, score, args.lambda_)
                    loss_v_sum_epoch += loss_v.float()

            if epoch > SWA_START:
                print("swa_mode")
                swa_nnue.update_parameters(nnue)
                swa_scheduler.step()
                checkpoint = {
                    'epoch': epoch + 1,
                    'state_dict': swa_nnue.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                save_ckp(checkpoint, save_location, 'swa_nnue.pt')

            else:

                scheduler.step(loss_v_sum_epoch / len(val_data))

                if loss_v_sum_epoch / len(val_data) <= best_loss:
                    best_loss = loss_v_sum_epoch / len(val_data)
                    is_best = True
                    early_stopping_count = 0
                else:
                    early_stopping_count += 1
                if early_stopping_delay == early_stopping_count:
                    early_stopping_flag = True

                if is_best:
                    checkpoint = {
                        'epoch': epoch + 1,
                        'state_dict': nnue.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    save_ckp(checkpoint, save_location)
                    is_best = False

            writer.add_scalar('val_loss', loss_v_sum_epoch / len(val_data),
                              epoch * len(train_data) + batch_idx)

            print("Epoch #{}\tVal_Loss: {:.8f}\t".format(
                epoch, loss_v_sum_epoch / len(val_data)))

    loss_v_sum_epoch = 0.0

    with torch.no_grad():
        swa_nnue.eval()
        for batch_idx, batch in enumerate(val_data):
            batch = [_data.cuda() for _data in batch]
            us, them, white, black, outcome, score = batch

            _output = swa_nnue(us, them, white, black)
            loss_v = nnue_loss(_output, outcome, score, args.lambda_)
            loss_v_sum_epoch += loss_v.float()

    print("Val_Loss: {:.8f}\t".format(loss_v_sum_epoch / len(val_data)))

    writer.close()
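
load_ckp and save_ckp are not shown; a sketch consistent with the checkpoint dicts built above (the default filename and directory handling are assumptions):

def save_ckp(state, save_location, filename='checkpoint.pt'):
    # Assumed helper: write the checkpoint dict to save_location/filename.
    os.makedirs(save_location, exist_ok=True)
    torch.save(state, os.path.join(save_location, filename))

def load_ckp(path, model, optimizer):
    # Assumed helper: restore model/optimizer state; returns the epoch to resume from.
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, ckpt['epoch']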