Code Example #1
File: train.py  Project: kennyfrc/nnue-pytorch
def main():
    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("train", help="Training data (.bin or .binpack)")
    parser.add_argument("val", help="Validation data (.bin or .binpack)")
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--lambda",
        default=1.0,
        type=float,
        dest='lambda_',
        help=
        "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)."
    )
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        dest='num_workers',
        help=
        "Number of worker threads to use for data loading. Currently only works well for binpack."
    )
    parser.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        dest='batch_size',
        help=
        "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128."
    )
    parser.add_argument(
        "--threads",
        default=-1,
        type=int,
        dest='threads',
        help="Number of torch threads to use. Default automatic (cores) .")
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        dest='seed',
                        help="torch seed to use.")
    parser.add_argument(
        "--smart-fen-skipping",
        action='store_true',
        dest='smart_fen_skipping_deprecated',
        help=
        "If enabled positions that are bad training targets will be skipped during loading. Default: True, kept for backwards compatibility. This option is ignored"
    )
    parser.add_argument(
        "--no-smart-fen-skipping",
        action='store_true',
        dest='no_smart_fen_skipping',
        help=
        "If used then no smart fen skipping will be done. By default smart fen skipping is done."
    )
    parser.add_argument(
        "--random-fen-skipping",
        default=3,
        type=int,
        dest='random_fen_skipping',
        help=
        "skip fens randomly on average random_fen_skipping before using one.")
    parser.add_argument(
        "--resume-from-model",
        dest='resume_from_model',
        help="Initializes training using the weights from the given .pt model")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    if not os.path.exists(args.train):
        raise Exception('{0} does not exist'.format(args.train))
    if not os.path.exists(args.val):
        raise Exception('{0} does not exist'.format(args.val))

    feature_set = features.get_feature_set_from_name(args.features)

    if args.resume_from_model is None:
        nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_)
        nnue.cuda()
    else:
        nnue = torch.load(args.resume_from_model)
        print("Resumed from model!")
        nnue.set_feature_set(feature_set)
        nnue.lambda_ = args.lambda_
        nnue.cuda()

    print("Feature set: {}".format(feature_set.name))
    print("Num real features: {}".format(feature_set.num_real_features))
    print("Num virtual features: {}".format(feature_set.num_virtual_features))
    print("Num features: {}".format(feature_set.num_features))

    print("Training with {} validating with {}".format(args.train, args.val))

    pl.seed_everything(args.seed)
    print("Seed {}".format(args.seed))

    batch_size = args.batch_size
    if batch_size <= 0:
        batch_size = 16384
    print('Using batch size {}'.format(batch_size))

    print('Smart fen skipping: {}'.format(not args.no_smart_fen_skipping))
    print('Random fen skipping: {}'.format(args.random_fen_skipping))

    if args.threads > 0:
        print('limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    logdir = args.default_root_dir if args.default_root_dir else 'logs/'
    print('Using log dir {}'.format(logdir), flush=True)

    tb_logger = pl_loggers.TensorBoardLogger(logdir)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=50,
        mode="min",
        monitor="val_loss",
        filename='{epoch}-{val_loss:.5f}',
        dirpath='logs')
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[checkpoint_callback],
                                            logger=tb_logger)

    main_device = trainer.root_device if trainer.root_gpu is None else 'cuda:' + str(
        trainer.root_gpu)

    print('Using c++ data loader')
    train, val = make_data_loaders(args.train, args.val, feature_set,
                                   args.num_workers, batch_size,
                                   not args.no_smart_fen_skipping,
                                   args.random_fen_skipping, main_device)

    trainer.fit(nnue, train, val)
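The --lambda option blends two training targets: the engine's evaluation score and the game result. A minimal sketch of the interpolation the help text describes (blended_target is a hypothetical helper, not a function from the project):

def blended_target(eval_score, game_result, lambda_):
    # lambda_ = 1.0 trains purely on evaluations, lambda_ = 0.0 purely on
    # game results; values in between interpolate linearly.
    return lambda_ * eval_score + (1.0 - lambda_) * game_result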
Code Example #2
def main():
    parser = argparse.ArgumentParser(
        description="Visualizes networks in ckpt, pt and nnue format.")
    parser.add_argument("model",
                        help="Source model (can be .ckpt, .pt or .nnue)")
    parser.add_argument(
        "--ref-model",
        type=str,
        required=False,
        help=
        "Visualize the difference between the given reference model (can be .ckpt, .pt or .nnue)."
    )
    parser.add_argument(
        "--ref-features",
        type=str,
        required=False,
        help=
        "The reference feature set to use (default = same as source model).")
    parser.add_argument(
        "--input-weights-vmin",
        default=-1,
        type=float,
        help=
        "Minimum of color map range for input weights (absolute values are plotted if this is positive or zero)."
    )
    parser.add_argument("--input-weights-vmax",
                        default=1,
                        type=float,
                        help="Maximum of color map range for input weights.")
    parser.add_argument(
        "--input-weights-auto-scale",
        action="store_true",
        help=
        "Use auto-scale for the color map range for input weights. This ignores input-weights-vmin and input-weights-vmax."
    )
    parser.add_argument(
        "--input-weights-order",
        type=str,
        choices=["piece-centric-flipped-king", "king-centric"],
        default="piece-centric-flipped-king",
        help="Order of the input weights for each input neuron.")
    parser.add_argument(
        "--sort-input-neurons",
        action="store_true",
        help=
        "Sort the neurons of the input layer by the L1-norm (sum of absolute values) of their weights."
    )
    parser.add_argument(
        "--fc-weights-vmin",
        default=-2,
        type=float,
        help=
        "Minimum of color map range for fully-connected layer weights (absolute values are plotted if this is positive or zero)."
    )
    parser.add_argument(
        "--fc-weights-vmax",
        default=2,
        type=float,
        help="Maximum of color map range for fully-connected layer weights.")
    parser.add_argument(
        "--fc-weights-auto-scale",
        action="store_true",
        help=
        "Use auto-scale for the color map range for fully-connected layer weights. This ignores fc-weights-vmin and fc-weights-vmax."
    )
    parser.add_argument("--no-hist",
                        action="store_true",
                        help="Don't generate any histograms.")
    parser.add_argument("--no-biases",
                        action="store_true",
                        help="Don't generate plots for biases.")
    parser.add_argument(
        "--no-input-weights",
        action="store_true",
        help="Don't generate plots or histograms for input weights.")
    parser.add_argument(
        "--no-fc-weights",
        action="store_true",
        help=
        "Don't generate plots or histograms for fully-connected layer weights."
    )
    parser.add_argument("--default-width",
                        default=1600,
                        type=int,
                        help="Default width of all plots (in pixels).")
    parser.add_argument("--default-height",
                        default=900,
                        type=int,
                        help="Default height of all plots (in pixels).")
    parser.add_argument("--save-dir",
                        type=str,
                        required=False,
                        help="Save the plots in this directory.")
    parser.add_argument("--dont-show",
                        action="store_true",
                        help="Don't show the plots.")
    parser.add_argument(
        "--label",
        type=str,
        required=False,
        help=
        "Override the label used in plot titles and as prefix of saved files.")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    supported_features = ('HalfKP', 'HalfKP^')
    assert args.features in supported_features
    feature_set = features.get_feature_set_from_name(args.features)

    from os.path import basename
    label = basename(args.model)

    model = load_model(args.model, feature_set)

    if args.ref_model:
        if args.ref_features:
            assert args.ref_features in supported_features
            ref_feature_set = features.get_feature_set_from_name(
                args.ref_features)
        else:
            ref_feature_set = feature_set

        ref_model = load_model(args.ref_model, ref_feature_set)

        print("Visualizing difference between {} and {}".format(
            args.model, args.ref_model))

        label = "diff " + label + "-" + basename(args.ref_model)
    else:
        ref_model = None
        print("Visualizing {}".format(args.model))

    if args.label is None:
        args.label = label

    visualizer = NNUEVisualizer(model, ref_model, args)

    visualizer.plot_input_weights()
    visualizer.plot_fc_weights()
    visualizer.plot_biases()

    if not args.dont_show:
        plt.show()
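The --sort-input-neurons flag orders input-layer neurons by the L1-norm of their weights. A minimal NumPy sketch of that ordering, assuming weights holds one column per input neuron (the array shape and names are illustrative, not taken from the project):

import numpy as np

weights = np.random.randn(41024, 256)  # (num_features, num_neurons), example shape

# L1-norm (sum of absolute values) of each neuron's weight column,
# then reorder columns from largest norm to smallest.
l1_norms = np.abs(weights).sum(axis=0)
order = np.argsort(-l1_norms)
sorted_weights = weights[:, order]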
Code Example #3
File: train.py  Project: OfekShochat/nnue-pytorch
def main():
    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("train", help="Training data (.bin or .binpack)")
    parser.add_argument("val", help="Validation data (.bin or .binpack)")
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument("--py-data",
                        action="store_true",
                        help="Use python data loader (default=False)")
    parser.add_argument(
        "--lambda",
        default=1.0,
        type=float,
        dest='lambda_',
        help=
        "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)."
    )
    parser.add_argument("--alpha",
                        default=1.0,
                        type=float,
                        dest='alpha_',
                        help="random multiply factor (default=1.0).")
    parser.add_argument(
        "--beta",
        default=6000,
        type=int,
        dest='beta_',
        help=
        "definite random step frequency - according to steps (default=6000).")
    parser.add_argument(
        "--gamma",
        default=0.0005,
        type=float,
        dest='gamma_',
        help="randomized random step frequency (default=0.0005).")
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        dest='num_workers',
        help=
        "Number of worker threads to use for data loading. Currently only works well for binpack."
    )
    parser.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        dest='batch_size',
        help=
        "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128."
    )
    parser.add_argument(
        "--threads",
        default=-1,
        type=int,
        dest='threads',
        help="Number of torch threads to use. Default automatic (cores) .")
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        dest='seed',
                        help="torch seed to use.")
    parser.add_argument(
        "--smart-fen-skipping",
        action='store_true',
        dest='smart_fen_skipping',
        help=
        "If enabled positions that are bad training targets will be skipped during loading. Default: False"
    )
    parser.add_argument(
        "--random-fen-skipping",
        default=0,
        type=int,
        dest='random_fen_skipping',
        help=
        "skip fens randomly on average random_fen_skipping before using one.")
    parser.add_argument(
        "--resume-from-model",
        dest='resume_from_model',
        help="Initializes training using the weights from the given .pt model")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    if not os.path.exists(args.train):
        raise Exception('{0} does not exist'.format(args.train))
    if not os.path.exists(args.val):
        raise Exception('{0} does not exist'.format(args.val))

    feature_set = features.get_feature_set_from_name(args.features)

    if args.resume_from_model is None:
        nnue = M.NNUE(feature_set=feature_set,
                      lambda_=args.lambda_,
                      alpha_=args.alpha_,
                      beta_=args.beta_,
                      gamma_=args.gamma_)
    else:
        nnue = torch.load(args.resume_from_model)
        nnue.set_feature_set(feature_set)
        nnue.lambda_ = args.lambda_
        nnue.alpha_ = args.alpha_
        nnue.beta_ = args.beta_
        nnue.gamma_ = args.gamma_

    print("Feature set: {}".format(feature_set.name))
    print("Num real features: {}".format(feature_set.num_real_features))
    print("Num virtual features: {}".format(feature_set.num_virtual_features))
    print("Num features: {}".format(feature_set.num_features))

    print("Training with {} validating with {}".format(args.train, args.val))

    pl.seed_everything(args.seed)
    print("Seed {}".format(args.seed))

    batch_size = args.batch_size
    if batch_size <= 0:
        batch_size = 128 if args.gpus == 0 else 8192
    print('Using batch size {}'.format(batch_size))

    print('Smart fen skipping: {}'.format(args.smart_fen_skipping))
    print('Random fen skipping: {}'.format(args.random_fen_skipping))

    if args.threads > 0:
        print('limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    logdir = args.default_root_dir if args.default_root_dir else 'logs/'
    print('Using log dir {}'.format(logdir), flush=True)

    wandb_logger = WandbLogger()
    checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=True,
                                                       period=5,
                                                       save_top_k=-1)
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[checkpoint_callback],
                                            logger=wandb_logger)

    main_device = trainer.root_device if trainer.root_gpu is None else 'cuda:' + str(
        trainer.root_gpu)

    if args.py_data:
        print('Using python data loader')
        train, val = data_loader_py(args.train, args.val, feature_set,
                                    batch_size, main_device)
    else:
        print('Using c++ data loader')
        train, val = data_loader_cc(args.train, args.val, feature_set,
                                    args.num_workers, batch_size,
                                    args.smart_fen_skipping,
                                    args.random_fen_skipping, main_device)

    trainer.fit(nnue, train, val)
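Per its help text, --random-fen-skipping skips positions so that on average random_fen_skipping fens pass before one is used. One way to realize that behavior (a sketch only; the project's actual loader is implemented in C++):

import random

def use_position(random_fen_skipping):
    # Keep a position with probability 1 / (random_fen_skipping + 1), so on
    # average random_fen_skipping positions are skipped per position used.
    return random.random() < 1.0 / (random_fen_skipping + 1)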
Code Example #4
File: train.py  Project: Xcrid/nnue-pytorch
def main():
    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("train", help="Training data (.bin or .binpack)")
    parser.add_argument("val", help="Validation data (.bin or .binpack)")

    parser.add_argument("--tune",
                        action="store_true",
                        help="automated LR search")
    parser.add_argument(
        "--save",
        action="store_true",
        help="save after every training epoch (default = False)")
    parser.add_argument("--experiment",
                        default="1",
                        type=str,
                        help="specify the experiment id")
    parser.add_argument("--py-data",
                        action="store_true",
                        help="Use python data loader (default=False)")
    parser.add_argument(
        "--lambda",
        default=1.0,
        type=float,
        dest='lambda_',
        help=
        "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)."
    )
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        dest='num_workers',
        help=
        "Number of worker threads to use for data loading. Currently only works well for binpack."
    )
    parser.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        dest='batch_size',
        help=
        "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128."
    )
    parser.add_argument(
        "--threads",
        default=-1,
        type=int,
        dest='threads',
        help="Number of torch threads to use. Default automatic (cores) .")
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        dest='seed',
                        help="torch seed to use.")
    parser.add_argument(
        "--smart-fen-skipping",
        action='store_true',
        dest='smart_fen_skipping',
        help=
        "If enabled positions that are bad training targets will be skipped during loading. Default: False"
    )
    parser.add_argument(
        "--random-fen-skipping",
        default=0,
        type=int,
        dest='random_fen_skipping',
        help=
        "skip fens randomly on average random_fen_skipping before using one.")
    parser.add_argument(
        "--resume-from-model",
        dest='resume_from_model',
        help="Initializes training using the weights from the given .pt model")

    features.add_argparse_args(parser)
    args = parser.parse_args()

    print("Training with {} validating with {}".format(args.train, args.val))

    # Seed torch (CPU and CUDA) with the --seed value for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    batch_size = args.batch_size
    if batch_size <= 0:
        # args.gpus does not exist in this script (pl.Trainer argparse args
        # are never registered), so base the default on CUDA availability.
        batch_size = 8192 if torch.cuda.is_available() else 128
    print('Using batch size {}'.format(batch_size))

    print('Smart fen skipping: {}'.format(args.smart_fen_skipping))
    print('Random fen skipping: {}'.format(args.random_fen_skipping))

    if args.threads > 0:
        print('limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    feature_set = features.get_feature_set_from_name(args.features)

    if args.py_data:
        print('Using python data loader')
        train_data, val_data = data_loader_py(args.train, args.val, batch_size,
                                              feature_set, 'cuda:0')

    else:
        print('Using c++ data loader')
        train_data, val_data = data_loader_cc(
            args.train, args.val, feature_set, args.num_workers, batch_size,
            args.smart_fen_skipping, args.random_fen_skipping, 'cuda:0')

    print("Feature set: {}".format(feature_set.name))
    print("Num real features: {}".format(feature_set.num_real_features))
    print("Num virtual features: {}".format(feature_set.num_virtual_features))
    print("Num features: {}".format(feature_set.num_features))

    START_EPOCH = 0
    NUM_EPOCHS = 150
    SWA_START = int(0.75 * NUM_EPOCHS)

    LEARNING_RATE = 5e-4
    DECAY = 0
    EPS = 1e-7

    best_loss = 1000
    is_best = False

    early_stopping_delay = 30
    early_stopping_count = 0
    early_stopping_flag = False

    summary_location = 'logs/nnue_experiment_' + args.experiment
    save_location = '/home/esigelec/PycharmProjects/nnue-pytorch/save_models/' + args.experiment

    writer = SummaryWriter(summary_location)

    nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_, s=1)

    train_params = [{
        'params': nnue.get_1xlr(),
        'lr': LEARNING_RATE
    }, {
        'params': nnue.get_10xlr(),
        'lr': LEARNING_RATE * 10.0
    }]

    optimizer = ranger.Ranger(train_params,
                              lr=LEARNING_RATE,
                              eps=EPS,
                              betas=(0.9, 0.999),
                              weight_decay=DECAY)

    if args.resume_from_model is not None:
        nnue, optimizer, START_EPOCH = load_ckp(args.resume_from_model, nnue,
                                                optimizer)
        nnue.set_feature_set(feature_set)
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.1,
                                                           patience=7,
                                                           cooldown=1,
                                                           min_lr=1e-7,
                                                           verbose=True)
    swa_scheduler = SWALR(optimizer, annealing_epochs=5, swa_lr=[5e-5, 1e-4])

    nnue = nnue.cuda()
    swa_nnue = AveragedModel(nnue)

    for epoch in range(START_EPOCH, NUM_EPOCHS):

        nnue.train()

        train_interval = 100
        loss_f_sum_interval = 0.0
        loss_f_sum_epoch = 0.0
        loss_v_sum_epoch = 0.0

        if early_stopping_flag:
            print("early end of training at epoch" + str(epoch))
            break

        for batch_idx, batch in enumerate(train_data):

            batch = [_data.cuda() for _data in batch]
            us, them, white, black, outcome, score = batch

            optimizer.zero_grad()
            output = nnue(us, them, white, black)

            loss = nnue_loss(output, outcome, score, args.lambda_)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(nnue.parameters(), 0.5)
            optimizer.step()

            loss_f_sum_interval += loss.float()
            loss_f_sum_epoch += loss.float()

            if batch_idx % train_interval == train_interval - 1:

                writer.add_scalar('train_loss',
                                  loss_f_sum_interval / train_interval,
                                  epoch * len(train_data) + batch_idx)

                loss_f_sum_interval = 0.0

        print("Epoch #{}\t Train_Loss: {:.8f}\t".format(
            epoch, loss_f_sum_epoch / len(train_data)))

        # epoch % 1 == 0 is always true, so validation runs every epoch.
        if epoch % 1 == 0 or (epoch + 1) == NUM_EPOCHS:

            with torch.no_grad():
                nnue.eval()
                for batch_idx, batch in enumerate(val_data):
                    batch = [_data.cuda() for _data in batch]
                    us, them, white, black, outcome, score = batch

                    _output = nnue(us, them, white, black)
                    loss_v = nnue_loss(_output, outcome, score, args.lambda_)
                    loss_v_sum_epoch += loss_v.float()

            if epoch > SWA_START:
                print("swa_mode")
                swa_nnue.update_parameters(nnue)
                swa_scheduler.step()
                checkpoint = {
                    'epoch': epoch + 1,
                    'state_dict': swa_nnue.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                save_ckp(checkpoint, save_location, 'swa_nnue.pt')

            else:

                scheduler.step(loss_v_sum_epoch / len(val_data))

                if loss_v_sum_epoch / len(val_data) <= best_loss:
                    best_loss = loss_v_sum_epoch / len(val_data)
                    is_best = True
                    early_stopping_count = 0
                else:
                    early_stopping_count += 1
                if early_stopping_delay == early_stopping_count:
                    early_stopping_flag = True

                if is_best:
                    checkpoint = {
                        'epoch': epoch + 1,
                        'state_dict': nnue.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    save_ckp(checkpoint, save_location)
                    is_best = False

            writer.add_scalar('val_loss', loss_v_sum_epoch / len(val_data),
                              epoch * len(train_data) + batch_idx)

            print("Epoch #{}\tVal_Loss: {:.8f}\t".format(
                epoch, loss_v_sum_epoch / len(val_data)))

    loss_v_sum_epoch = 0.0

    with torch.no_grad():
        swa_nnue.eval()
        for batch_idx, batch in enumerate(val_data):
            batch = [_data.cuda() for _data in batch]
            us, them, white, black, outcome, score = batch

            _output = swa_nnue(us, them, white, black)
            loss_v = nnue_loss(_output, outcome, score, args.lambda_)
            loss_v_sum_epoch += loss_v.float()

    print("Val_Loss: {:.8f}\t".format(loss_v_sum_epoch / len(val_data)))

    writer.close()
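save_ckp and load_ckp are not shown in this excerpt. A minimal sketch consistent with how they are called above, where the checkpoint dict carries 'epoch', 'state_dict', and 'optimizer' (the file layout is an assumption):

import os
import torch

def save_ckp(state, save_dir, filename='nnue.pt'):
    # Persist the checkpoint dict assembled in the training loop.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, filename))

def load_ckp(path, model, optimizer):
    # Restore weights and optimizer state; return the epoch to resume from.
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']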