Example #1
    def __init__(self, f, feature_set):
        self.f = f
        self.feature_set = feature_set
        self.model = M.NNUE(feature_set)
        fc_hash = NNUEWriter.fc_hash(self.model)

        self.read_header(feature_set, fc_hash)
        self.read_int32(feature_set.hash ^ (M.L1 * 2))  # Feature transformer hash
        self.read_feature_transformer(self.model.input, self.model.num_psqt_buckets)
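        # Each bucket's FC stack is serialized independently: read every bucket
        # into temporary layers, then copy them into the stacked parameters.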
        for i in range(self.model.num_ls_buckets):
            l1 = nn.Linear(2 * M.L1, M.L2)
            l2 = nn.Linear(M.L2, M.L3)
            output = nn.Linear(M.L3, 1)
            self.read_int32(fc_hash)  # FC layers hash
            self.read_fc_layer(l1)
            self.read_fc_layer(l2)
            self.read_fc_layer(output, is_output=True)
            self.model.layer_stacks.l1.weight.data[i * M.L2:(i + 1) * M.L2, :] = l1.weight
            self.model.layer_stacks.l1.bias.data[i * M.L2:(i + 1) * M.L2] = l1.bias
            self.model.layer_stacks.l2.weight.data[i * M.L3:(i + 1) * M.L3, :] = l2.weight
            self.model.layer_stacks.l2.bias.data[i * M.L3:(i + 1) * M.L3] = l2.bias
            self.model.layer_stacks.output.weight.data[i:(i + 1), :] = output.weight
            self.model.layer_stacks.output.bias.data[i:(i + 1)] = output.bias
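
The readers in these examples share a few low-level helpers (read_int32, read_fc_layer, read_feature_transformer) that are not shown. A minimal sketch of the byte-level plumbing they might sit on, assuming a little-endian stream; the exact dtypes and quantization handling are project-specific, and the class and method bodies here are illustrative only:

import struct

import numpy as np
import torch

class BinaryReaderHelpers:
    # Hypothetical helpers backing the read_* calls above.
    def __init__(self, f):
        self.f = f

    def read_int32(self, expected=None):
        # Consume 4 bytes as a little-endian int32, optionally checking a hash.
        v = struct.unpack('<i', self.f.read(4))[0]
        if expected is not None and v != expected:
            raise Exception('Expected hash {:#x}, got {:#x}'.format(expected, v))
        return v

    def tensor(self, dtype, shape):
        # Read a flat buffer of `dtype` values and view it as a float tensor.
        d = np.fromfile(self.f, dtype, int(np.prod(shape)))
        return torch.from_numpy(d.astype(np.float32)).reshape(shape)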
Example #2
def main():
  config = C.Config('config.yaml')

  sample_to_device = lambda x: tuple(map(lambda t: t.to(config.device, non_blocking=True), x))

  M = model.NNUE().to(config.device)

  if path.exists(config.model_save_path):
    print('Loading model ... ')
    # map_location keeps CPU-only runs from failing on CUDA-saved weights
    M.load_state_dict(torch.load(config.model_save_path, map_location=config.device))

  train_data = ataxx_dataset.AtaxxData(config.train_ataxx_data_path, config)
  validation_data = ataxx_dataset.AtaxxData(config.validation_ataxx_data_path, config)
  
  train_data_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=True,
    worker_init_fn=ataxx_dataset.worker_init_fn)

  validation_data_loader = torch.utils.data.DataLoader(validation_data, batch_size=config.batch_size)
  validation_data_loader_iter = iter(validation_data_loader)

  writer = SummaryWriter(config.visual_directory)

  writer.add_graph(M, sample_to_device(next(iter(train_data_loader)))[:3])

  opt = optim.Adadelta(M.parameters(), lr=config.learning_rate)
  scheduler = optim.lr_scheduler.StepLR(opt, 1, gamma=0.5)

  queue = []
  
  for epoch in range(1, config.epochs + 1):
    for i, sample in enumerate(train_data_loader):
      # update visual data
      if (i % config.test_rate) == 0 and i != 0:
        step = train_data.cardinality() * (epoch - 1) + i * config.batch_size
        train_loss = sum(queue) / len(queue)
        
        validation_sample = next(validation_data_loader_iter, None)
        if validation_sample is None:
          validation_data_loader_iter = iter(validation_data_loader)
          validation_sample = next(validation_data_loader_iter, None)
        validation_sample = sample_to_device(validation_sample)
        
        validation_loss = get_validation_loss(M, validation_sample)
        
        writer.add_scalar('train_loss', train_loss, step)
        writer.add_scalar('validation_loss', validation_loss, step)
      
      if (i % config.save_rate) == 0 and i != 0:
        print('Saving model ...')
        M.to_binary_file(config.bin_model_save_path)
        torch.save(M.state_dict(), config.model_save_path)

      train_step(M, sample_to_device(sample), opt, queue, max_queue_size=config.max_queue_size, report=(0 == i % config.report_rate))
    scheduler.step()
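
train_step and get_validation_loss are not shown in this example. A minimal sketch of a train_step consistent with how it is called above; the loss_fn name and exact sample layout are assumptions, not the project's actual API:

def train_step(M, sample, opt, queue, max_queue_size, report=False):
    # Forward pass; assumes the model exposes a loss over the sample tuple.
    loss = M.loss_fn(sample)
    if report:
        print(loss.item())
    # Bounded window of recent losses, averaged into train_loss above.
    queue.append(loss.item())
    if len(queue) > max_queue_size:
        queue.pop(0)
    opt.zero_grad()
    loss.backward()
    opt.step()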
Example #3
def main():
  parser = argparse.ArgumentParser(description="Trains the network.")
  parser.add_argument("train", help="Training data (.bin or .binpack)")
  parser.add_argument("val", help="Validation data (.bin or .binpack)")
  parser.add_argument("--architecture", default='normal', help="architecture of model")
  parser = pl.Trainer.add_argparse_args(parser)
  parser.add_argument("--py-data", action="store_true", help="Use python data loader (default=False)")
  parser.add_argument("--lambda", default=1.0, type=float, dest='lambda_', help="lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0).")
  parser.add_argument("--num-workers", default=1, type=int, dest='num_workers', help="Number of worker threads to use for data loading. Currently only works well for binpack.")
  parser.add_argument("--batch-size", default=-1, type=int, dest='batch_size', help="Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128.")
  parser.add_argument("--threads", default=-1, type=int, dest='threads', help="Number of torch threads to use. Default automatic (cores) .")
  parser.add_argument("--seed", default=42, type=int, dest='seed', help="torch seed to use.")
  parser.add_argument("--smart-fen-skipping", action='store_true', dest='smart_fen_skipping', help="If enabled positions that are bad training targets will be skipped during loading. Default: False")
  args = parser.parse_args()

  if args.architecture.lower() == "leiser":
      data_name = halfkp.LEISER_NAME
      model_inputs = halfkp.LEISER_INPUTS
  elif args.architecture.lower() == "normal":
      data_name = halfkp.NAME
      model_inputs = halfkp.INPUTS
  else:
      raise Exception("Incorrect architecture name")

  nnue = M.NNUE(num_inputs=model_inputs, lambda_=args.lambda_)

  print("Training with {} validating with {}".format(args.train, args.val))

  pl.seed_everything(args.seed)
  print("Seed {}".format(args.seed))


  batch_size = args.batch_size
  if batch_size <= 0:
    batch_size = 128 if args.gpus == 0 else 8192
  print('Using batch size {}'.format(batch_size))

  print('Smart fen skipping: {}'.format(args.smart_fen_skipping))

  if args.threads > 0:
    print('limiting torch to {} threads.'.format(args.threads))
    t_set_num_threads(args.threads)

  if args.py_data:
    print('Using python data loader')
    train, val = data_loader_py(args.train, args.val, batch_size)
  else:
    print('Using c++ data loader')
    train, val = data_loader_cc(args.train, args.val, data_name, args.num_workers, batch_size, args.smart_fen_skipping)

  logdir = args.default_root_dir if args.default_root_dir else 'logs/'
  print('Using log dir {}'.format(logdir), flush=True)

  tb_logger = pl_loggers.TensorBoardLogger(logdir)
  checkpoint_callback = pl.callbacks.ModelCheckpoint(save_top_k=1, save_last=True, monitor='val_loss', filename='best_model')
  trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], logger=tb_logger)
  trainer.fit(nnue, train, val)
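
For reference, an illustrative command line for this script (paths and flag values are examples only, not from the project): python train.py train.binpack val.binpack --gpus 1 --batch-size 8192 --num-workers 4 --smart-fen-skipping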
Example #4
    def __init__(self, f, num_inputs):
        self.f = f
        self.model = M.NNUE(num_inputs=num_inputs)

        self.read_header()
        self.read_int32(0x5d69d7b8)  # Feature transformer hash
        self.read_feature_transformer(self.model.input)
        self.read_int32(0x63337156)  # FC layers hash
        self.read_fc_layer(self.model.l1)
        self.read_fc_layer(self.model.l2)
        self.read_fc_layer(self.model.output, is_output=True)
Example #5
  def __init__(self, f, feature_set):
    self.f = f
    self.feature_set = feature_set
    self.model = M.NNUE(feature_set)

    self.read_header(feature_set)
    self.read_int32(feature_set.hash) # Feature transformer hash
    self.read_feature_transformer(self.model.input)
    self.read_int32(FC_HASH) # FC layers hash
    self.read_fc_layer(self.model.l1)
    self.read_fc_layer(self.model.l2)
    self.read_fc_layer(self.model.output, is_output=True)
Example #6
    def __init__(self, f, feature_set):
        self.f = f
        self.feature_set = feature_set
        self.model = M.NNUE(feature_set)
        fc_hash = NNUEWriter.fc_hash(self.model)

        self.read_header(feature_set, fc_hash)
        self.read_int32(feature_set.hash ^ (M.L1 * 2))  # Feature transformer hash
        self.read_feature_transformer(self.model.input)
        self.read_int32(fc_hash)  # FC layers hash
        self.read_fc_layer(self.model.l1)
        self.read_fc_layer(self.model.l2)
        self.read_fc_layer(self.model.output, is_output=True)
Example #7
def main():
    parser = argparse.ArgumentParser(
        description="Converts files between ckpt and nnue format.")
    parser.add_argument("source",
                        help="Source file (can be .ckpt, .pt or .nnue)")
    parser.add_argument("target", help="Target file (can be .pt or .nnue)")
    parser.add_argument("--architecture",
                        default="normal",
                        help="model architecture (leiser or normal)")
    args = parser.parse_args()

    if args.architecture.lower() == "leiser":
        num_inputs = halfkp.LEISER_INPUTS
    elif args.architecture.lower() == "normal":
        num_inputs = halfkp.INPUTS
    else:
        raise Exception("Non-valid architecture selection" + args.architecture)

    print('Converting %s to %s' % (args.source, args.target))

    if args.source.endswith(".pt") or args.source.endswith(".ckpt"):
        if not args.target.endswith(".nnue"):
            raise Exception("Target file must end with .nnue")
        if args.source.endswith(".pt"):
            nnue = torch.load(args.source)
        else:
            nnue = M.NNUE(num_inputs=num_inputs)
            checkpoint = torch.load(args.source)
            nnue.load_state_dict(checkpoint['state_dict'])
        nnue.eval()
        #test(nnue)
        writer = NNUEWriter(nnue)
        with open(args.target, 'wb') as f:
            f.write(writer.buf)
    elif args.source.endswith(".nnue"):
        if not args.target.endswith(".pt"):
            raise Exception("Target file must end with .pt")
        with open(args.source, 'rb') as f:
            reader = NNUEReader(f, num_inputs)
        torch.save(reader.model, args.target)
    else:
        raise Exception('Invalid filetypes: ' + str(args))
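
Illustrative usage (file names are examples only): python convert.py last.ckpt model.nnue exports a training checkpoint to the engine's .nnue format, while python convert.py model.nnue model.pt converts a serialized net back into a torch model.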
Example #8
File: train.py, Project: tdh1967/seer-nnue
def main():
  config = C.Config('config.yaml')

  sample_to_device = lambda x: tuple(map(lambda t: t.to(config.device, non_blocking=True), x))

  M = model.NNUE().to(config.device)

  if path.exists(config.model_save_path):
    print('Loading model ... ')
    # map_location keeps CPU-only runs from failing on CUDA-saved weights
    M.load_state_dict(torch.load(config.model_save_path, map_location=config.device))

  data = nnue_bin_dataset.NNUEBinData(config)
  data_loader = torch.utils.data.DataLoader(
    data,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=True,
    worker_init_fn=nnue_bin_dataset.worker_init_fn)

  opt = optim.Adadelta(M.parameters(), lr=config.learning_rate)
  scheduler = optim.lr_scheduler.StepLR(opt, 1, gamma=0.5)

  loss_history = []
  queue = []
  
  for epoch in range(1, config.epochs + 1):
    for i, sample in enumerate(data_loader):
      # update visual data
      if (i % config.test_rate) == 0 and i != 0:
        loss_history.append(sum(queue) / len(queue))
        plt.clf()
        plt.plot(loss_history)
        plt.savefig('{}/loss_graph.png'.format(config.visual_directory), bbox_inches='tight')
      
      if (i % config.save_rate) == 0 and i != 0:
        print('Saving model ...')
        M.to_binary_file(config.bin_model_save_path)
        torch.save(M.state_dict(), config.model_save_path)

      train_step(M, sample_to_device(sample), opt, queue, max_queue_size=config.max_queue_size, lambda_=config.lambda_, report=(0 == i % config.report_rate))

    scheduler.step()
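
The worker_init_fn passed to the DataLoader is defined in nnue_bin_dataset and not shown. A plausible sketch, assuming hypothetical shard_id/num_shards attributes that let each worker read a disjoint slice of the .bin file:

import torch

def worker_init_fn(worker_id):
    info = torch.utils.data.get_worker_info()
    # Seed each worker independently so random skipping differs per worker.
    torch.manual_seed(info.seed % (2**31))
    # Hypothetical sharding attributes; the real dataset may partition differently.
    info.dataset.shard_id = worker_id
    info.dataset.num_shards = info.num_workers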
Example #9
def main():
    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("train", help="Training data (.bin or .binpack)")
    parser.add_argument("val", help="Validation data (.bin or .binpack)")
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--lambda",
        default=1.0,
        type=float,
        dest='lambda_',
        help=
        "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)."
    )
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        dest='num_workers',
        help=
        "Number of worker threads to use for data loading. Currently only works well for binpack."
    )
    parser.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        dest='batch_size',
        help=
        "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128."
    )
    parser.add_argument(
        "--threads",
        default=-1,
        type=int,
        dest='threads',
        help="Number of torch threads to use. Default automatic (cores) .")
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        dest='seed',
                        help="torch seed to use.")
    parser.add_argument(
        "--smart-fen-skipping",
        action='store_true',
        dest='smart_fen_skipping_deprecated',
        help=
        "If enabled positions that are bad training targets will be skipped during loading. Default: True, kept for backwards compatibility. This option is ignored"
    )
    parser.add_argument(
        "--no-smart-fen-skipping",
        action='store_true',
        dest='no_smart_fen_skipping',
        help=
        "If used then no smart fen skipping will be done. By default smart fen skipping is done."
    )
    parser.add_argument(
        "--random-fen-skipping",
        default=3,
        type=int,
        dest='random_fen_skipping',
        help=
        "skip fens randomly on average random_fen_skipping before using one.")
    parser.add_argument(
        "--resume-from-model",
        dest='resume_from_model',
        help="Initializes training using the weights from the given .pt model")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    if not os.path.exists(args.train):
        raise Exception('{0} does not exist'.format(args.train))
    if not os.path.exists(args.val):
        raise Exception('{0} does not exist'.format(args.val))

    feature_set = features.get_feature_set_from_name(args.features)

    if args.resume_from_model is None:
        nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_)
        nnue.cuda()
    else:
        nnue = torch.load(args.resume_from_model)
        nnue.set_feature_set(feature_set)
        nnue.lambda_ = args.lambda_
        nnue.cuda()

    print("Feature set: {}".format(feature_set.name))
    print("Num real features: {}".format(feature_set.num_real_features))
    print("Num virtual features: {}".format(feature_set.num_virtual_features))
    print("Num features: {}".format(feature_set.num_features))

    print("Training with {} validating with {}".format(args.train, args.val))

    pl.seed_everything(args.seed)
    print("Seed {}".format(args.seed))

    batch_size = args.batch_size
    if batch_size <= 0:
        batch_size = 16384
    print('Using batch size {}'.format(batch_size))

    print('Smart fen skipping: {}'.format(not args.no_smart_fen_skipping))
    print('Random fen skipping: {}'.format(args.random_fen_skipping))

    if args.threads > 0:
        print('limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    logdir = args.default_root_dir if args.default_root_dir else 'logs/'
    print('Using log dir {}'.format(logdir), flush=True)

    tb_logger = pl_loggers.TensorBoardLogger(logdir)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=True,
                                                       period=20,
                                                       save_top_k=-1)
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[checkpoint_callback],
                                            logger=tb_logger)

    main_device = trainer.root_device if trainer.root_gpu is None else 'cuda:' + str(trainer.root_gpu)

    print('Using c++ data loader')
    train, val = make_data_loaders(args.train, args.val, feature_set,
                                   args.num_workers, batch_size,
                                   not args.no_smart_fen_skipping,
                                   args.random_fen_skipping, main_device)

    trainer.fit(nnue, train, val)
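
Note that ModelCheckpoint(period=20) saves a checkpoint every 20 epochs and save_top_k=-1 keeps all of them; in later PyTorch Lightning releases the period argument was replaced by every_n_epochs, so this snippet targets the older API.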
Example #10
from os import path
import torch
import ataxx

import config as C
import util
import model

config = C.Config('config.yaml')

M = model.NNUE().to(config.device)

if path.exists(config.model_save_path):
  print('Loading model ... ')
  M.load_state_dict(torch.load(config.model_save_path, map_location=config.device))

num_parameters = sum(map(lambda x: torch.numel(x), M.parameters()))

print(num_parameters)

M.cpu()

while True:
  bd = ataxx.Board(input("fen: "))
  white, black = util.to_tensors(bd)
  val = M(torch.tensor([bd.turn]).float(), white.unsqueeze(0).float(), black.unsqueeze(0).float())
  print(val)
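
Entering a FEN evaluates that position; for instance, the standard Ataxx start position (commonly written as x5o/7/7/7/7/7/o5x x 0 1) should print a single-element tensor holding the network's score.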
Example #11
def main(args):
    # Select which device to use
    if torch.cuda.is_available():
        main_device = 'cuda:0'
    else:
        main_device = 'cpu'

    # Create directories to store data and logs in
    output_path = prepare_output_directory()
    log_path = prepare_log_directory()

    # Print configuration info
    print(f'Device: {main_device}')
    print(f'Training set: {args.train}')
    print(f'Validation set: {args.val}')
    print(f'Batch size: {args.batch_size}')
    print(f'Using factorizer: {args.use_factorizer}')
    print(f'Lambda: {args.lambda_}')
    print(f'Validation check interval: {args.val_check_interval}')
    print(f'Logs written to: {log_path}')
    print(f'Data written to: {output_path}')
    print('')

    # Create log writer
    writer = SummaryWriter(log_path)

    # Create data loaders
    train_data_loader, val_data_loader = create_data_loaders(
        args.train, args.val, args.train_size, args.val_size, args.batch_size,
        args.use_factorizer, main_device)

    # Create model
    nnue = M.NNUE(args.use_factorizer,
                  feature_set=halfkp.Features()).to(main_device)

    # Configure optimizer
    optimizer = ranger.Ranger(nnue.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.1,
                                                           patience=1,
                                                           verbose=True,
                                                           min_lr=1e-6)

    # Main training loop
    num_batches = len(train_data_loader)
    epoch = 0
    running_train_loss = 0.0
    # Track the best validation loss across all epochs (not reset per epoch)
    best_val_loss = 1000000.0
    while True:

        for k, sample in enumerate(train_data_loader):
            train_loss = train_step(nnue, sample, optimizer, args.lambda_,
                                    epoch, k, num_batches)
            running_train_loss += train_loss.item()

            if k % args.val_check_interval == (args.val_check_interval - 1):
                val_loss = calculate_validation_loss(nnue, val_data_loader,
                                                     args.lambda_)
                new_best = False
                if val_loss < best_val_loss:
                    new_best = True
                    best_val_loss = val_loss
                save_model(nnue, output_path, epoch, k, val_loss, new_best,
                           False)
                writer.add_scalar('training loss',
                                  running_train_loss / args.val_check_interval,
                                  epoch * num_batches + k)
                writer.add_scalar('validation loss', val_loss,
                                  epoch * num_batches + k)
                running_train_loss = 0.0

        val_loss = calculate_validation_loss(nnue, val_data_loader,
                                             args.lambda_)
        new_best = False
        if val_loss < best_val_loss:
            new_best = True
            best_val_loss = val_loss
        save_model(nnue, output_path, epoch, num_batches - 1, val_loss,
                   new_best, True)
        print('')

        scheduler.step(val_loss)
        epoch += 1
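
calculate_validation_loss is not shown above. A minimal sketch, assuming the model exposes the same loss interface used by the training step (the nnue.loss call is a hypothetical API), averaged over the full validation loader:

def calculate_validation_loss(nnue, val_data_loader, lambda_):
    nnue.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for sample in val_data_loader:
            # Hypothetical loss API mirroring the training step.
            total += nnue.loss(sample, lambda_).item()
            count += 1
    nnue.train()
    return total / count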
Example #12
def main():

    t_set_printoptions(profile="full")

    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("--py-data",
                        action="store_true",
                        help="Use python data loader (default=False)")
    parser.add_argument("train", help="Training data (.bin or .binpack)")
    parser.add_argument("val", help="Validation data (.bin or .binpack)")
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--lambda",
        default=1.0,
        type=float,
        dest='lambda_',
        help=
        "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)."
    )
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        dest='num_workers',
        help=
        "Number of worker threads to use for data loading. Currently only works well for binpack."
    )
    parser.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        dest='batch_size',
        help=
        "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128."
    )
    parser.add_argument(
        "--threads",
        default=-1,
        type=int,
        dest='threads',
        help="Number of torch threads to use. Default automatic (cores) .")
    parser.add_argument(
        "--random-fen-skipping",
        default=0,
        type=int,
        dest='random_fen_skipping',
        help=
        "skip fens randomly on average random_fen_skipping before using one.")
    parser.add_argument(
        "--smart-fen-skipping",
        action='store_true',
        dest='smart_fen_skipping',
        help=
        "If enabled positions that are bad training targets will be skipped during loading. Default: False"
    )
    args = parser.parse_args()

    nnue = M.NNUE(lambda_=args.lambda_)

    print("Training with {} validating with {}".format(args.train, args.val))

    batch_size = args.batch_size
    if batch_size <= 0:
        batch_size = 1024 if args.gpus == 0 else 2048
    print('Using batch size {}'.format(batch_size))

    if args.threads > 0:
        print('limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    tb_logger = pl_loggers.TensorBoardLogger('logs/')
    checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=True,
                                                       save_top_k=-1)
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[checkpoint_callback],
                                            logger=tb_logger,
                                            profiler='advanced')

    main_device = trainer.root_device if trainer.root_gpu is None else 'cuda:' + str(trainer.root_gpu)

    if args.py_data:
        print('Using python data loader')
        train, val = data_loader_py(args.train, args.val, args.num_workers,
                                    batch_size)
    else:
        print('Using c++ data loader')
        train, val = data_loader_cc(args.train, args.val, args.num_workers,
                                    batch_size, args.smart_fen_skipping,
                                    args.random_fen_skipping, main_device)

    trainer.fit(nnue, train, val)
Example #13
def main():
    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("train", help="Training data (.bin or .binpack)")
    parser.add_argument("val", help="Validation data (.bin or .binpack)")
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument("--py-data",
                        action="store_true",
                        help="Use python data loader (default=False)")
    parser.add_argument(
        "--lambda",
        default=1.0,
        type=float,
        dest='lambda_',
        help=
        "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)."
    )
    parser.add_argument("--alpha",
                        default=1.0,
                        type=float,
                        dest='alpha_',
                        help="random multiply factor (default=1.0).")
    parser.add_argument(
        "--beta",
        default=6000,
        type=int,
        dest='beta_',
        help=
        "definite random step frequency - according to steps (default=6000).")
    parser.add_argument(
        "--gamma",
        default=0.0005,
        type=float,
        dest='gamma_',
        help="randomized random step frequency (default=0.0005).")
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        dest='num_workers',
        help=
        "Number of worker threads to use for data loading. Currently only works well for binpack."
    )
    parser.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        dest='batch_size',
        help=
        "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128."
    )
    parser.add_argument(
        "--threads",
        default=-1,
        type=int,
        dest='threads',
        help="Number of torch threads to use. Default automatic (cores) .")
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        dest='seed',
                        help="torch seed to use.")
    parser.add_argument(
        "--smart-fen-skipping",
        action='store_true',
        dest='smart_fen_skipping',
        help=
        "If enabled positions that are bad training targets will be skipped during loading. Default: False"
    )
    parser.add_argument(
        "--random-fen-skipping",
        default=0,
        type=int,
        dest='random_fen_skipping',
        help=
        "skip fens randomly on average random_fen_skipping before using one.")
    parser.add_argument(
        "--resume-from-model",
        dest='resume_from_model',
        help="Initializes training using the weights from the given .pt model")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    if not os.path.exists(args.train):
        raise Exception('{0} does not exist'.format(args.train))
    if not os.path.exists(args.val):
        raise Exception('{0} does not exist'.format(args.val))

    feature_set = features.get_feature_set_from_name(args.features)

    if args.resume_from_model is None:
        nnue = M.NNUE(feature_set=feature_set,
                      lambda_=args.lambda_,
                      alpha_=args.alpha_,
                      beta_=args.beta_,
                      gamma_=args.gamma_)
    else:
        nnue = torch.load(args.resume_from_model)
        nnue.set_feature_set(feature_set)
        nnue.lambda_ = args.lambda_
        nnue.alpha_ = args.alpha_
        nnue.beta_ = args.beta_
        nnue.gamma_ = args.gamma_

    print("Feature set: {}".format(feature_set.name))
    print("Num real features: {}".format(feature_set.num_real_features))
    print("Num virtual features: {}".format(feature_set.num_virtual_features))
    print("Num features: {}".format(feature_set.num_features))

    print("Training with {} validating with {}".format(args.train, args.val))

    pl.seed_everything(args.seed)
    print("Seed {}".format(args.seed))

    batch_size = args.batch_size
    if batch_size <= 0:
        batch_size = 128 if args.gpus == 0 else 8192
    print('Using batch size {}'.format(batch_size))

    print('Smart fen skipping: {}'.format(args.smart_fen_skipping))
    print('Random fen skipping: {}'.format(args.random_fen_skipping))

    if args.threads > 0:
        print('limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    logdir = args.default_root_dir if args.default_root_dir else 'logs/'
    print('Using log dir {}'.format(logdir), flush=True)

    wandb_logger = WandbLogger()
    checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=True,
                                                       period=5,
                                                       save_top_k=-1)
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[checkpoint_callback],
                                            logger=wandb_logger)

    main_device = trainer.root_device if trainer.root_gpu is None else 'cuda:' + str(trainer.root_gpu)

    if args.py_data:
        print('Using python data loader')
        train, val = data_loader_py(args.train, args.val, feature_set,
                                    batch_size, main_device)
    else:
        print('Using c++ data loader')
        train, val = data_loader_cc(args.train, args.val, feature_set,
                                    args.num_workers, batch_size,
                                    args.smart_fen_skipping,
                                    args.random_fen_skipping, main_device)

    trainer.fit(nnue, train, val)