def main(): parser = argparse.ArgumentParser(description="Trains the network.") parser.add_argument("train", help="Training data (.bin or .binpack)") parser.add_argument("val", help="Validation data (.bin or .binpack)") parser = pl.Trainer.add_argparse_args(parser) parser.add_argument( "--lambda", default=1.0, type=float, dest='lambda_', help= "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)." ) parser.add_argument( "--num-workers", default=1, type=int, dest='num_workers', help= "Number of worker threads to use for data loading. Currently only works well for binpack." ) parser.add_argument( "--batch-size", default=-1, type=int, dest='batch_size', help= "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128." ) parser.add_argument( "--threads", default=-1, type=int, dest='threads', help="Number of torch threads to use. Default automatic (cores) .") parser.add_argument("--seed", default=42, type=int, dest='seed', help="torch seed to use.") parser.add_argument( "--smart-fen-skipping", action='store_true', dest='smart_fen_skipping_deprecated', help= "If enabled positions that are bad training targets will be skipped during loading. Default: True, kept for backwards compatibility. This option is ignored" ) parser.add_argument( "--no-smart-fen-skipping", action='store_true', dest='no_smart_fen_skipping', help= "If used then no smart fen skipping will be done. By default smart fen skipping is done." ) parser.add_argument( "--random-fen-skipping", default=3, type=int, dest='random_fen_skipping', help= "skip fens randomly on average random_fen_skipping before using one.") parser.add_argument( "--resume-from-model", dest='resume_from_model', help="Initializes training using the weights from the given .pt model") features.add_argparse_args(parser) args = parser.parse_args() if not os.path.exists(args.train): raise Exception('{0} does not exist'.format(args.train)) if not os.path.exists(args.val): raise Exception('{0} does not exist'.format(args.val)) feature_set = features.get_feature_set_from_name(args.features) if args.resume_from_model is None: nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_) nnue.cuda() else: nnue = torch.load(args.resume_from_model) print("Resumed from model!") nnue.set_feature_set(feature_set) nnue.lambda_ = args.lambda_ nnue.cuda() print("Feature set: {}".format(feature_set.name)) print("Num real features: {}".format(feature_set.num_real_features)) print("Num virtual features: {}".format(feature_set.num_virtual_features)) print("Num features: {}".format(feature_set.num_features)) print("Training with {} validating with {}".format(args.train, args.val)) pl.seed_everything(args.seed) print("Seed {}".format(args.seed)) batch_size = args.batch_size if batch_size <= 0: batch_size = 16384 print('Using batch size {}'.format(batch_size)) print('Smart fen skipping: {}'.format(not args.no_smart_fen_skipping)) print('Random fen skipping: {}'.format(args.random_fen_skipping)) if args.threads > 0: print('limiting torch to {} threads.'.format(args.threads)) t_set_num_threads(args.threads) logdir = args.default_root_dir if args.default_root_dir else 'logs/' print('Using log dir {}'.format(logdir), flush=True) tb_logger = pl_loggers.TensorBoardLogger(logdir) checkpoint_callback = pl.callbacks.ModelCheckpoint( save_top_k=50, mode="min", monitor="val_loss", filename='{epoch}-{val_loss:.5f}', dirpath='logs') trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], 
logger=tb_logger) main_device = trainer.root_device if trainer.root_gpu is None else 'cuda:' + str( trainer.root_gpu) print('Using c++ data loader') train, val = make_data_loaders(args.train, args.val, feature_set, args.num_workers, batch_size, not args.no_smart_fen_skipping, args.random_fen_skipping, main_device) trainer.fit(nnue, train, val)
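
# A hedged usage sketch (not from the original source): the script name and the concrete
# paths/values below are placeholders, but every flag comes from the argparse definitions
# in main() above (--gpus is added by pl.Trainer.add_argparse_args).
#
#   python train.py train.binpack val.binpack --gpus 1 --threads 2 \
#       --num-workers 4 --batch-size 16384 --lambda 1.0 --random-fen-skipping 3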
def main():
    parser = argparse.ArgumentParser(description="Visualizes networks in ckpt, pt and nnue format.")
    parser.add_argument("model", help="Source model (can be .ckpt, .pt or .nnue)")
    parser.add_argument("--ref-model", type=str, required=False,
                        help="Visualize the difference between the given reference model "
                             "(can be .ckpt, .pt or .nnue).")
    parser.add_argument("--ref-features", type=str, required=False,
                        help="The reference feature set to use (default = same as source model).")
    parser.add_argument("--input-weights-vmin", default=-1, type=float,
                        help="Minimum of color map range for input weights "
                             "(absolute values are plotted if this is positive or zero).")
    parser.add_argument("--input-weights-vmax", default=1, type=float,
                        help="Maximum of color map range for input weights.")
    parser.add_argument("--input-weights-auto-scale", action="store_true",
                        help="Use auto-scale for the color map range for input weights. "
                             "This ignores input-weights-vmin and input-weights-vmax.")
    parser.add_argument("--input-weights-order", type=str,
                        choices=["piece-centric-flipped-king", "king-centric"],
                        default="piece-centric-flipped-king",
                        help="Order of the input weights for each input neuron.")
    parser.add_argument("--sort-input-neurons", action="store_true",
                        help="Sort the neurons of the input layer by the L1-norm "
                             "(sum of absolute values) of their weights.")
    parser.add_argument("--fc-weights-vmin", default=-2, type=float,
                        help="Minimum of color map range for fully-connected layer weights "
                             "(absolute values are plotted if this is positive or zero).")
    parser.add_argument("--fc-weights-vmax", default=2, type=float,
                        help="Maximum of color map range for fully-connected layer weights.")
    parser.add_argument("--fc-weights-auto-scale", action="store_true",
                        help="Use auto-scale for the color map range for fully-connected layer weights. "
                             "This ignores fc-weights-vmin and fc-weights-vmax.")
    parser.add_argument("--no-hist", action="store_true", help="Don't generate any histograms.")
    parser.add_argument("--no-biases", action="store_true", help="Don't generate plots for biases.")
    parser.add_argument("--no-input-weights", action="store_true",
                        help="Don't generate plots or histograms for input weights.")
    parser.add_argument("--no-fc-weights", action="store_true",
                        help="Don't generate plots or histograms for fully-connected layer weights.")
    parser.add_argument("--default-width", default=1600, type=int,
                        help="Default width of all plots (in pixels).")
    parser.add_argument("--default-height", default=900, type=int,
                        help="Default height of all plots (in pixels).")
    parser.add_argument("--save-dir", type=str, required=False,
                        help="Save the plots in this directory.")
    parser.add_argument("--dont-show", action="store_true", help="Don't show the plots.")
    parser.add_argument("--label", type=str, required=False,
                        help="Override the label used in plot titles and as prefix of saved files.")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    supported_features = ('HalfKP', 'HalfKP^')
    assert args.features in supported_features
    feature_set = features.get_feature_set_from_name(args.features)

    from os.path import basename
    label = basename(args.model)
    model = load_model(args.model, feature_set)

    if args.ref_model:
        if args.ref_features:
            assert args.ref_features in supported_features
            ref_feature_set = features.get_feature_set_from_name(args.ref_features)
        else:
            ref_feature_set = feature_set
        ref_model = load_model(args.ref_model, ref_feature_set)
        print("Visualizing difference between {} and {}".format(args.model, args.ref_model))
        label = "diff " + label + "-" + basename(args.ref_model)
    else:
        ref_model = None
        print("Visualizing {}".format(args.model))

    if args.label is None:
        args.label = label

    visualizer = NNUEVisualizer(model, ref_model, args)
    visualizer.plot_input_weights()
    visualizer.plot_fc_weights()
    visualizer.plot_biases()

    if not args.dont_show:
        plt.show()
def main(): parser = argparse.ArgumentParser(description="Trains the network.") parser.add_argument("train", help="Training data (.bin or .binpack)") parser.add_argument("val", help="Validation data (.bin or .binpack)") parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--py-data", action="store_true", help="Use python data loader (default=False)") parser.add_argument( "--lambda", default=1.0, type=float, dest='lambda_', help= "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)." ) parser.add_argument("--alpha", default=1.0, type=float, dest='alpha_', help="random multiply factor (default=1.0).") parser.add_argument( "--beta", default=6000, type=int, dest='beta_', help= "definite random step frequency - according to steps (default=6000).") parser.add_argument( "--gamma", default=0.0005, type=float, dest='gamma_', help="randomized random step frequency (default=0.0005).") parser.add_argument( "--num-workers", default=1, type=int, dest='num_workers', help= "Number of worker threads to use for data loading. Currently only works well for binpack." ) parser.add_argument( "--batch-size", default=-1, type=int, dest='batch_size', help= "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128." ) parser.add_argument( "--threads", default=-1, type=int, dest='threads', help="Number of torch threads to use. Default automatic (cores) .") parser.add_argument("--seed", default=42, type=int, dest='seed', help="torch seed to use.") parser.add_argument( "--smart-fen-skipping", action='store_true', dest='smart_fen_skipping', help= "If enabled positions that are bad training targets will be skipped during loading. Default: False" ) parser.add_argument( "--random-fen-skipping", default=0, type=int, dest='random_fen_skipping', help= "skip fens randomly on average random_fen_skipping before using one.") parser.add_argument( "--resume-from-model", dest='resume_from_model', help="Initializes training using the weights from the given .pt model") features.add_argparse_args(parser) args = parser.parse_args() if not os.path.exists(args.train): raise Exception('{0} does not exist'.format(args.train)) if not os.path.exists(args.val): raise Exception('{0} does not exist'.format(args.val)) feature_set = features.get_feature_set_from_name(args.features) if args.resume_from_model is None: nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_, alpha_=args.alpha_, beta_=args.beta_, gamma=args.gamma_) else: nnue = torch.load(args.resume_from_model) nnue.set_feature_set(feature_set) nnue.lambda_ = args.lambda_ nnue.alpha_ = args.alpha_ nnue.beta_ = args.beta_ nnue.gamma_ = args.gamma_ print("Feature set: {}".format(feature_set.name)) print("Num real features: {}".format(feature_set.num_real_features)) print("Num virtual features: {}".format(feature_set.num_virtual_features)) print("Num features: {}".format(feature_set.num_features)) print("Training with {} validating with {}".format(args.train, args.val)) pl.seed_everything(args.seed) print("Seed {}".format(args.seed)) batch_size = args.batch_size if batch_size <= 0: batch_size = 128 if args.gpus == 0 else 8192 print('Using batch size {}'.format(batch_size)) print('Smart fen skipping: {}'.format(args.smart_fen_skipping)) print('Random fen skipping: {}'.format(args.random_fen_skipping)) if args.threads > 0: print('limiting torch to {} threads.'.format(args.threads)) t_set_num_threads(args.threads) logdir = args.default_root_dir if args.default_root_dir else 'logs/' print('Using 
log dir {}'.format(logdir), flush=True) wandb_logger = WandbLogger() checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=True, period=5, save_top_k=-1) trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], logger=wandb_logger) main_device = trainer.root_device if trainer.root_gpu is None else 'cuda:' + str( trainer.root_gpu) if args.py_data: print('Using python data loader') train, val = data_loader_py(args.train, args.val, feature_set, batch_size, main_device) else: print('Using c++ data loader') train, val = data_loader_cc(args.train, args.val, feature_set, args.num_workers, batch_size, args.smart_fen_skipping, args.random_fen_skipping, main_device) trainer.fit(nnue, train, val)
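
# The --lambda help above describes an interpolation between two training targets.
# A minimal illustrative sketch of that blending (an assumption for illustration,
# not the exact loss code used by these trainers):
def blended_target(eval_target, result_target, lambda_):
    # lambda_=1.0 -> train purely on engine evaluations,
    # lambda_=0.0 -> train purely on game results; values in between blend the two.
    return lambda_ * eval_target + (1.0 - lambda_) * result_target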
def main(): parser = argparse.ArgumentParser(description="Trains the network.") parser.add_argument("train", help="Training data (.bin or .binpack)") parser.add_argument("val", help="Validation data (.bin or .binpack)") parser.add_argument("--tune", action="store_true", help="automated LR search") parser.add_argument( "--save", action="store_true", help="save after every training epoch (default = False)") parser.add_argument("--experiment", default="1", type=str, help="specify the experiment id") parser.add_argument("--py-data", action="store_true", help="Use python data loader (default=False)") parser.add_argument( "--lambda", default=1.0, type=float, dest='lambda_', help= "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)." ) parser.add_argument( "--num-workers", default=1, type=int, dest='num_workers', help= "Number of worker threads to use for data loading. Currently only works well for binpack." ) parser.add_argument( "--batch-size", default=-1, type=int, dest='batch_size', help= "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128." ) parser.add_argument( "--threads", default=-1, type=int, dest='threads', help="Number of torch threads to use. Default automatic (cores) .") parser.add_argument("--seed", default=42, type=int, dest='seed', help="torch seed to use.") parser.add_argument( "--smart-fen-skipping", action='store_true', dest='smart_fen_skipping', help= "If enabled positions that are bad training targets will be skipped during loading. Default: False" ) parser.add_argument( "--random-fen-skipping", default=0, type=int, dest='random_fen_skipping', help= "skip fens randomly on average random_fen_skipping before using one.") parser.add_argument( "--resume-from-model", dest='resume_from_model', help="Initializes training using the weights from the given .pt model") features.add_argparse_args(parser) args = parser.parse_args() print("Training with {} validating with {}".format(args.train, args.val)) torch.manual_seed(123) torch.cuda.manual_seed(123) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True batch_size = args.batch_size if batch_size <= 0: batch_size = 128 if args.gpus == 0 else 8192 print('Using batch size {}'.format(batch_size)) print('Smart fen skipping: {}'.format(args.smart_fen_skipping)) print('Random fen skipping: {}'.format(args.random_fen_skipping)) if args.threads > 0: print('limiting torch to {} threads.'.format(args.threads)) t_set_num_threads(args.threads) feature_set = features.get_feature_set_from_name(args.features) if args.py_data: print('Using python data loader') train_data, val_data = data_loader_py(args.train, args.val, batch_size, feature_set, 'cuda:0') else: print('Using c++ data loader') train_data, val_data = data_loader_cc( args.train, args.val, feature_set, args.num_workers, batch_size, args.smart_fen_skipping, args.random_fen_skipping, 'cuda:0') print("Feature set: {}".format(feature_set.name)) print("Num real features: {}".format(feature_set.num_real_features)) print("Num virtual features: {}".format(feature_set.num_virtual_features)) print("Num features: {}".format(feature_set.num_features)) START_EPOCH = 0 NUM_EPOCHS = 150 SWA_START = int(0.75 * NUM_EPOCHS) LEARNING_RATE = 5e-4 DECAY = 0 EPS = 1e-7 best_loss = 1000 is_best = False early_stopping_delay = 30 early_stopping_count = 0 early_stopping_flag = False summary_location = 'logs/nnue_experiment_' + args.experiment save_location = 
'/home/esigelec/PycharmProjects/nnue-pytorch/save_models/' + args.experiment writer = SummaryWriter(summary_location) nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_, s=1) train_params = [{ 'params': nnue.get_1xlr(), 'lr': LEARNING_RATE }, { 'params': nnue.get_10xlr(), 'lr': LEARNING_RATE * 10.0 }] optimizer = ranger.Ranger(train_params, lr=LEARNING_RATE, eps=EPS, betas=(0.9, 0.999), weight_decay=DECAY) if args.resume_from_model is not None: nnue, optimizer, START_EPOCH = load_ckp(args.resume_from_model, nnue, optimizer) nnue.set_feature_set(feature_set) for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda() scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=7, cooldown=1, min_lr=1e-7, verbose=True) swa_scheduler = SWALR(optimizer, annealing_epochs=5, swa_lr=[5e-5, 1e-4]) nnue = nnue.cuda() swa_nnue = AveragedModel(nnue) for epoch in range(START_EPOCH, NUM_EPOCHS): nnue.train() train_interval = 100 loss_f_sum_interval = 0.0 loss_f_sum_epoch = 0.0 loss_v_sum_epoch = 0.0 if early_stopping_flag: print("early end of training at epoch" + str(epoch)) break for batch_idx, batch in enumerate(train_data): batch = [_data.cuda() for _data in batch] us, them, white, black, outcome, score = batch optimizer.zero_grad() output = nnue(us, them, white, black) loss = nnue_loss(output, outcome, score, args.lambda_) loss.backward() torch.nn.utils.clip_grad_norm_(nnue.parameters(), 0.5) optimizer.step() loss_f_sum_interval += loss.float() loss_f_sum_epoch += loss.float() if batch_idx % train_interval == train_interval - 1: writer.add_scalar('train_loss', loss_f_sum_interval / train_interval, epoch * len(train_data) + batch_idx) loss_f_sum_interval = 0.0 print("Epoch #{}\t Train_Loss: {:.8f}\t".format( epoch, loss_f_sum_epoch / len(train_data))) if epoch % 1 == 0 or (epoch + 1) == NUM_EPOCHS: with torch.no_grad(): nnue.eval() for batch_idx, batch in enumerate(val_data): batch = [_data.cuda() for _data in batch] us, them, white, black, outcome, score = batch _output = nnue(us, them, white, black) loss_v = nnue_loss(_output, outcome, score, args.lambda_) loss_v_sum_epoch += loss_v.float() if epoch > SWA_START: print("swa_mode") swa_nnue.update_parameters(nnue) swa_scheduler.step() checkpoint = { 'epoch': epoch + 1, 'state_dict': swa_nnue.state_dict(), 'optimizer': optimizer.state_dict() } save_ckp(checkpoint, save_location, 'swa_nnue.pt') else: scheduler.step(loss_v_sum_epoch / len(val_data)) if loss_v_sum_epoch / len(val_data) <= best_loss: best_loss = loss_v_sum_epoch / len(val_data) is_best = True early_stopping_count = 0 else: early_stopping_count += 1 if early_stopping_delay == early_stopping_count: early_stopping_flag = True if is_best: checkpoint = { 'epoch': epoch + 1, 'state_dict': nnue.state_dict(), 'optimizer': optimizer.state_dict() } save_ckp(checkpoint, save_location) is_best = False writer.add_scalar('val_loss', loss_v_sum_epoch / len(val_data), epoch * len(train_data) + batch_idx) print("Epoch #{}\tVal_Loss: {:.8f}\t".format( epoch, loss_v_sum_epoch / len(val_data))) loss_v_sum_epoch = 0.0 with torch.no_grad(): swa_nnue.eval() for batch_idx, batch in enumerate(val_data): batch = [_data.cuda() for _data in batch] us, them, white, black, outcome, score = batch _output = swa_nnue(us, them, white, black) loss_v = nnue_loss(_output, outcome, score, args.lambda_) loss_v_sum_epoch += loss_v.float() print("Val_Loss: {:.8f}\t".format(loss_v_sum_epoch / len(val_data))) writer.close()
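
# save_ckp and load_ckp are called above but not defined in this section. A minimal
# sketch consistent with those call sites (hypothetical helpers, not necessarily the
# originals); it assumes `os` and `torch` are imported at module level, as elsewhere
# in this file. load_ckp returns (model, optimizer, start_epoch) and save_ckp persists
# the {'epoch', 'state_dict', 'optimizer'} dict built in the training loop.
def save_ckp(state, save_dir, filename='nnue.pt'):
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, filename))

def load_ckp(checkpoint_path, model, optimizer):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']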