class Coach():
    """
    This class executes the self-play + learning. It uses the functions defined
    in Game and NeuralNet. args are specified in main.py.

    Each learning iteration generates self-play examples with MCTS, retrains
    the network on the recent example history, evaluates it against the
    configured metric opponents and against the previous network, and keeps
    the new network only if it beats the previous one often enough.
    """

    def __init__(self, game, nnet, args):
        self.game = game
        self.args = args
        self.nnet = nnet
        # the competitor network: a fresh instance of the same wrapper class
        self.pnet = self.nnet.__class__(self.game, self.args)
        self.mcts = MCTS(self.game, self.nnet, self.args)
        # history of examples from args.numItersForTrainExamplesHistory latest iterations
        self.trainExamplesHistory = []
        self.skipFirstSelfPlay = False  # can be overridden in loadTrainExamples()
        self.elo = 0  # elo score of the current model
        self.logger = logging.getLogger(self.__class__.__name__)

        start_time = datetime.datetime.now().strftime('%m%d_%H%M%S')
        # setup visualization writer instance
        writer_dir = os.path.join(self.args.log_dir, self.args.name, start_time)
        self.writer = WriterTensorboardX(writer_dir, self.logger, self.args.tensorboardX)

    def executeEpisode(self):
        """
        Play one episode of self-play, starting with player 1.

        On every turn MCTS produces a policy pi for the canonical board, and
        every symmetry of (board, pi) is recorded together with the player to
        move. The move is sampled from pi with temp=1 while
        episodeStep < tempThreshold, and with temp=0 afterwards. When the game
        ends, its outcome is used to assign a value to every recorded example.

        Returns:
            trainExamples: a list of examples of the form (canonicalBoard, pi, v).
                pi is the MCTS-informed policy vector. v is the value r returned
                by getGameEnded for the player to move when the game ended, with
                its sign flipped for examples recorded by the other player.
        """
        self.mcts = MCTS(self.game, self.nnet, self.args)  # reset search tree
        trainExamples = []
        board = self.game.getInitBoard()
        self.curPlayer = 1
        episodeStep = 0

        while True:
            episodeStep += 1
            canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)
            temp = int(episodeStep < self.args.tempThreshold)

            pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
            sym = self.game.getSymmetries(canonicalBoard, pi)
            for b, p in sym:
                trainExamples.append([b, self.curPlayer, p, None])

            action = np.random.choice(len(pi), p=pi)
            board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)

            r = self.game.getGameEnded(board, self.curPlayer)

            if r != 0:
                # getGameEnded was asked about self.curPlayer, so keep r for that
                # player's examples and negate it for the opponent's.
                return [(x[0], x[2], r * ((-1) ** (x[1] != self.curPlayer)))
                        for x in trainExamples]

    def learn(self):
        """
        Run numIters iterations with numEps episodes of self-play in each.

        Each iteration's examples are capped at maxlenOfQueue and appended to
        trainExamplesHistory, which keeps at most numItersForTrainExamplesHistory
        iterations. The network is retrained on the shuffled history. It is
        then pitted against every opponent in args.metric_opponents for
        reporting, and against the previous network. The new network is
        accepted only if it wins >= updateThreshold of the decisive games
        against the previous one; otherwise the previous weights are restored.
        """
        for i in tqdm(range(1, self.args.numIters + 1), desc='Iteration'):
            # examples of the iteration
            if not self.skipFirstSelfPlay or i > 1:
                iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)
                for _ in tqdm(range(self.args.numEps), desc='mcts.Episode'):
                    iterationTrainExamples += self.executeEpisode()
                # save the iteration examples to the history
                self.trainExamplesHistory.append(iterationTrainExamples)

            self._trimTrainExamplesHistory()
            # backup history to a file
            # NB! the examples were collected using the model from the previous iteration, so (i-1)
            self.saveTrainExamples(i - 1)

            # flatten and shuffle examples before training
            trainExamples = [ex for iteration in self.trainExamplesHistory for ex in iteration]
            shuffle(trainExamples)

            # training new network, keeping a copy of the old one
            self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            self.pnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            pmcts = MCTS(self.game, self.pnet, self.args)

            self.nnet.train(trainExamples, self.writer)
            self.writer.set_step(i - 1, "learning")
            nmcts = MCTS(self.game, self.nnet, self.args)

            print("PITTING AGAINST METRIC COMPONENTS")
            for metric_opponent in self.args.metric_opponents:
                # The lambda is only invoked inside playGames(), i.e. before nmcts is rebound below.
                arena = Arena(lambda x: np.argmax(nmcts.getActionProb(x, temp=0)),
                              metric_opponent(self.game).play, self.game)
                nwins, owins, draws = arena.playGames(self.args.metricArenaCompare)
                print('%s WINS : %d / %d ; DRAWS : %d' %
                      (metric_opponent.__name__, nwins, owins, draws))
                self.writer.add_scalar('{}_win'.format(metric_opponent.__name__),
                                       self._win_fraction(nwins, owins))

            # Reset nmcts so the metric games do not leak search statistics into self-evaluation
            nmcts = MCTS(self.game, self.nnet, self.args)

            print('PITTING AGAINST PREVIOUS VERSION')
            arena = Arena(lambda x: np.argmax(pmcts.getActionProb(x, temp=0)),
                          lambda x: np.argmax(nmcts.getActionProb(x, temp=0)), self.game)
            pwins, nwins, draws = arena.playGames(self.args.arenaCompare)
            win_prct = self._win_fraction(nwins, pwins)
            self.writer.add_scalar('self_win', win_prct)

            # Calculate elo score for self play
            results = [-x for x in arena.get_results()]  # flip to be next neural network wins
            nelo, pelo = elo(self.elo, self.elo, results)

            print('NEW/PREV WINS : %d / %d ; DRAWS : %d' % (nwins, pwins, draws))
            if pwins + nwins == 0 or win_prct < self.args.updateThreshold:
                print('REJECTING NEW MODEL')
                self.elo = pelo
                self.nnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            else:
                print('ACCEPTING NEW MODEL')
                self.elo = nelo
                self.nnet.save_checkpoint(folder=self.args.checkpoint,
                                          filename=self.getCheckpointFile(i))
                self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='best.pth.tar')
            self.writer.add_scalar('self_elo', self.elo)

    @staticmethod
    def _win_fraction(wins, losses):
        """Return wins / (wins + losses), or 0 when no decisive game was played."""
        decisive = wins + losses
        if decisive == 0:
            return 0
        return float(wins) / decisive

    def _trimTrainExamplesHistory(self):
        """Drop the oldest iterations until at most numItersForTrainExamplesHistory remain.

        A loop (not a single pop) so that a longer history loaded from disk is
        trimmed to the configured size immediately.
        """
        while len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
            print("len(trainExamplesHistory) =", len(self.trainExamplesHistory),
                  " => remove the oldest trainExamples")
            self.trainExamplesHistory.pop(0)

    def getCheckpointFile(self, iteration):
        """Return the checkpoint filename for the given iteration number."""
        return 'checkpoint_' + str(iteration) + '.pth.tar'

    def saveTrainExamples(self, iteration):
        """Pickle the whole trainExamplesHistory next to the iteration's checkpoint.

        The file is <args.checkpoint>/checkpoint_<iteration>.pth.tar.examples;
        the folder is created if it does not exist.
        """
        folder = self.args.checkpoint
        os.makedirs(folder, exist_ok=True)
        filename = os.path.join(folder, self.getCheckpointFile(iteration) + ".examples")
        with open(filename, "wb+") as f:
            Pickler(f).dump(self.trainExamplesHistory)

    def loadTrainExamples(self):
        """Load trainExamplesHistory saved next to the model in args.load_folder_file.

        If the examples file is missing, the user is asked whether to continue;
        any answer other than "y" exits the process. When the file is found,
        the history is replaced and the first self-play phase of learn() is
        skipped.

        SECURITY: this unpickles the file, which can execute arbitrary code -
        only load example files you produced yourself.
        """
        modelFile = os.path.join(self.args.load_folder_file[0], self.args.load_folder_file[1])
        examplesFile = modelFile + ".examples"
        if not os.path.isfile(examplesFile):
            print(examplesFile)
            r = input("File with trainExamples not found. Continue? [y|n]")
            if r != "y":
                sys.exit()
        else:
            print("File with trainExamples found. Read it.")
            with open(examplesFile, "rb") as f:
                self.trainExamplesHistory = Unpickler(f).load()
            # examples based on the model were already collected (loaded)
            self.skipFirstSelfPlay = True
class BaseTrainer:
    """
    Base class for all trainers.

    Subclasses implement _train_epoch(epoch). This class handles device
    placement, config/checkpoint directories, metric monitoring with early
    stopping, checkpoint saving and resuming.
    """

    def __init__(self, model, loss, metrics, optimizer, resume, config, train_logger=None):
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)

        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.train_logger = train_logger

        cfg_trainer = config['train']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_p']
        self.verbosity = cfg_trainer['verbosity']
        self.monitor = cfg_trainer.get('monitor', 'off')
        # Always defined so code paths that read it never hit AttributeError.
        self.early_stop = cfg_trainer.get('early_stop', math.inf)

        # configuration to monitor model performance and save best
        # monitor is either 'off' or "<min|max> <metric_name>"
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = math.inf if self.mnt_mode == 'min' else -math.inf

        self.start_epoch = 1

        # setup directory for checkpoint saving
        start_time = datetime.datetime.now().strftime('%m%d_%H%M%S')
        self.checkpoint_dir = os.path.join(cfg_trainer['save_dir'], start_time, 'checkpoints')
        self.log_dir = os.path.join(cfg_trainer['save_dir'], start_time, 'logs')
        self.writer = WriterTensorboardX(self.log_dir, self.logger, cfg_trainer['tbX'])

        # Save configuration file into checkpoint directory:
        mkdir_p(self.checkpoint_dir)

        if self.config.get('cfg', None) is not None:
            cfg_save_path = os.path.join(self.checkpoint_dir, 'model.cfg')
            # Read the source fully (and close it) before writing the copy.
            with open(self.config['cfg']) as src:
                cfg_text = src.read()
            with open(cfg_save_path, 'w') as fw:
                fw.write(cfg_text)
            self.config['cfg'] = cfg_save_path

        config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
        with open(config_save_path, 'w') as handle:
            json.dump(self.config, handle, indent=4, sort_keys=False)

        if resume:
            self._resume_checkpoint(resume)

    def train(self):
        """
        Full training logic.

        Runs _train_epoch for every epoch from start_epoch to epochs (inclusive).
        It records each epoch's results and the current learning rate, and
        tracks the monitored metric for early stopping. Checkpoints are saved
        every save_period epochs and on every epoch that improves the
        monitored metric, so model_best.pth is never skipped.
        """
        best_log = None  # log dict of the best epoch so far; None until something improves
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            # _train_epoch returns dict with train metrics ("metrics"), validation
            # metrics ("val_metrics") and other key,value pairs. Store/update them in log.
            result = self._train_epoch(epoch)
            log = self._build_log(epoch, result)

            c_lr = self.optimizer.param_groups[0]['lr']

            # print logged informations to the screen
            if self.train_logger is not None:
                self.train_logger.add_entry(log)
                if self.verbosity >= 1:
                    self.logger.info('{}'.format(self._format_log(log)))
                    self.logger.info('lr_0: {}'.format(c_lr))
            self.writer.add_scalar('lr', c_lr)

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] < self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. Model performance monitoring is disabled."
                        .format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False
                    not_improved_count = 0

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                    best_log = log
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. Training stops."
                        .format(self.early_stop))
                    # best_log stays None if the metric never improved (e.g. it was NaN).
                    if best_log is not None:
                        self.logger.info('Final:\n{}'.format(self._format_log(best_log)))
                    break

            if len(self.writer) > 0:
                self.logger.info(
                    '\nRun TensorboardX:\ntensorboard --logdir={}\n'.format(self.log_dir))

            if epoch % self.save_period == 0 or best:
                self._save_checkpoint(epoch, save_best=best)

    def _build_log(self, epoch, result):
        """Flatten an epoch result dict into a log dict keyed by metric name.

        'metrics' values are stored under each metric function's __name__,
        'val_metrics' under 'val_' + __name__; other keys are copied as-is.
        """
        log = {'epoch': epoch}
        for key, value in result.items():
            if key == 'metrics':
                log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
            elif key == 'val_metrics':
                log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
            else:
                log[key] = value
        return log

    @staticmethod
    def _format_log(log):
        """Render a log dict as a one-column pandas table without the 'epoch' row."""
        df = pd.DataFrame([log]).T
        df.columns = ['']
        return df.loc[df.index != 'epoch']

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        :return: dict of results; see _build_log for the recognised keys
        """
        raise NotImplementedError

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Save a checkpoint to checkpoint-current.pth in checkpoint_dir (overwritten each time).

        :param epoch: current epoch number
        :param save_best: if True, also write the same state to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'logger': self.train_logger,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config,
            'classes': self.model.classes
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint-current.pth')
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: {} ...".format('model_best.pth'))
            self.logger.info("[IMPROVED]")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        # map_location lets a checkpoint written on GPU be resumed on a CPU-only host.
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        self.model.load_state_dict(checkpoint['state_dict'])

        # NOTE(review): the optimizer state is loaded unconditionally; this assumes the
        # optimizer type has not changed since the checkpoint was written - confirm.
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.train_logger = checkpoint['logger']
        self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch))
def train(self, examples, writer=None):
    """
    Train self.nnet on self-play examples for self.args.epochs epochs.

    examples: list of examples, each example is of form (board, pi, v)
    writer: optional tensorboardX writer. When None, the writer built by
            WriterTensorboardX(None, None, False) is used instead (presumably a
            disabled/no-op writer - confirm against WriterTensorboardX).

    A new optimizer (self.args.optimizer with lr=self.args.lr and
    self.args.optimizer_kwargs) and LR scheduler (self.args.lr_scheduler with
    self.args.lr_scheduler_kwargs) are created on every call. Each epoch runs
    len(examples) // batch_size mini-batches whose example indices are sampled
    with replacement. self.train_iteration offsets the writer's global step and
    is incremented once at the end of the call.
    """
    optimizer = self.args.optimizer(self.nnet.parameters(), lr=self.args.lr,
                                    **self.args.optimizer_kwargs)
    scheduler = self.args.lr_scheduler(optimizer, **self.args.lr_scheduler_kwargs)

    # If no writer, create unusable writer
    if writer is None:
        writer = WriterTensorboardX(None, None, False)

    epoch_bar = tqdm(desc="Training Epoch", total=self.args.epochs)
    for epoch in range(self.args.epochs):
        self.nnet.train()

        pi_losses = AverageMeter()
        v_losses = AverageMeter()
        total_losses = AverageMeter()

        # NOTE: with fewer examples than batch_size this is 0 and the epoch runs no batches.
        num_batches = len(examples) // self.args.batch_size
        step_offset = (self.train_iteration * self.args.epochs * num_batches) + (epoch * num_batches)

        bar = tqdm(desc='Batch', total=num_batches)
        for batch_idx in range(num_batches):
            writer.set_step(step_offset + batch_idx)

            sample_ids = np.random.randint(len(examples), size=self.args.batch_size)
            boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
            boards = torch.FloatTensor(np.array(boards).astype(np.float64))
            target_pis = torch.FloatTensor(np.array(pis))
            target_vs = torch.FloatTensor(np.array(vs).astype(np.float64))

            if self.args.cuda:
                boards = boards.contiguous().cuda()
                target_pis = target_pis.contiguous().cuda()
                target_vs = target_vs.contiguous().cuda()

            # predict and compute losses
            out_pi, out_v = self.nnet(boards)
            l_pi = self.loss_pi(target_pis, out_pi)
            l_v = self.loss_v(target_vs, out_v)
            total_loss = l_pi + l_v

            pi_losses.update(l_pi.item(), boards.size(0))
            v_losses.update(l_v.item(), boards.size(0))
            total_losses.update(total_loss.item(), boards.size(0))

            # record loss
            writer.add_scalar('pi_loss', l_pi.item())
            writer.add_scalar('v_loss', l_v.item())
            writer.add_scalar('loss', total_loss.item())

            # compute gradient and do SGD step
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # plot progress
            bar.set_postfix(
                lpi=l_pi.item(),
                lv=l_v.item(),
                loss=total_loss.item()
            )
            bar.update()
        bar.close()

        # Advance the LR schedule once per epoch, AFTER this epoch's optimizer.step() calls.
        # Stepping it first (as before) skips the schedule's first value in PyTorch >= 1.1.
        scheduler.step()

        writer.set_step((self.train_iteration * self.args.epochs) + epoch, 'train_epoch')
        writer.add_scalar('epoch_pi_loss', pi_losses.avg)
        writer.add_scalar('epoch_v_loss', v_losses.avg)
        writer.add_scalar('epoch_loss', total_losses.avg)

        epoch_bar.set_postfix(
            avg_lpi=pi_losses.avg,
            avg_lv=v_losses.avg,
            avg_l=total_losses.avg
        )
        epoch_bar.update()

    epoch_bar.close()

    self.train_iteration += 1