class Trainer(object):
    """Trainer

    Builds the model, validator and Adam optimizer from the configuration
    selected by ``args.proto`` and drives training: per-batch optimisation,
    periodic logging, learning-rate scheduling, validation / early stopping
    and the final test / dev translation.
    """

    def __init__(self, args):
        super(Trainer, self).__init__()
        self.config = configurations.get_config(
            args.proto, getattr(configurations, args.proto),
            args.config_overrides)
        self.num_preload = args.num_preload
        self.lr = self.config['lr']

        # Start from an empty output directory.
        ut.remove_files_in_dir(self.config['save_to'])
        self.logger = ut.get_logger(self.config['log_file'])

        self.train_smooth_perps = []
        self.train_true_perps = []

        # For logging: per-batch stats are buffered here until the next
        # periodic log (or the end of the epoch) folds them into epoch totals.
        self.log_freq = self.config['log_freq']  # log train stat every this-many batches
        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        self.log_grad_norms = []

        self.total_batches = 0  # number of batches done for the whole training
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_nll_loss = 0.  # total train nll loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch

        # get model
        device = ut.get_device()
        self.model = Model(self.config).to(device)
        self.validator = Validator(self.config, self.model)

        self.validate_freq = self.config['validate_freq']
        if self.validate_freq == 1:
            self.logger.info('Evaluate every ' + (
                'epoch' if self.config['val_per_epoch'] else 'batch'))
        else:
            self.logger.info(f'Evaluate every {self.validate_freq:,} ' + (
                'epochs' if self.config['val_per_epoch'] else 'batches'))

        # Estimated number of batches per epoch. Kept >= 1 because run_log
        # divides by it to compute the progress percentage.
        self.est_batches = max(
            1,
            max(self.model.data_manager.training_tok_counts)
            // self.config['batch_size'])
        self.logger.info(
            f'Guessing around {self.est_batches:,} batches per epoch')

        param_count = sum(
            [numpy.prod(p.size()) for p in self.model.parameters()])
        self.logger.info(f'Model has {int(param_count):,} parameters')

        # Set up parameter-specific options: one param group per tensor, plus
        # any extra optimizer options the model registered for that tensor in
        # model.parameter_attrs (keyed by the tensor's data_ptr()).
        params = []
        for p in self.model.parameters():
            ptr = p.data_ptr()
            group = {'params': [p]}
            if ptr in self.model.parameter_attrs:
                attrs = self.model.parameter_attrs[ptr]
                for k in attrs:
                    group[k] = attrs[k]
            params.append(group)
        self.optimizer = torch.optim.Adam(
            params, lr=self.lr,
            betas=(self.config['beta1'], self.config['beta2']),
            eps=self.config['epsilon'])

    @staticmethod
    def _to_perp(avg_nll):
        """Return exp(avg_nll), or inf when avg_nll >= 300 (or NaN) to avoid overflow."""
        return numpy.exp(avg_nll) if avg_nll < 300 else float('inf')

    def _flush_log_window(self):
        """Fold the buffered per-batch stats into the epoch totals and clear the buffers.

        When ``grad_clamp`` is enabled, batches whose smoothed or NLL loss is
        non-finite are excluded from the loss sums and from the counted
        weights (they still count toward ``epoch_weights``, which is used
        for words/sec).

        :return: (smoothed loss sum, NLL loss sum, weights of counted batches)
                 as Python floats; all 0.0 when nothing is buffered.
        """
        if not self.log_train_loss:
            self.log_grad_norms = []
            return 0., 0., 0.

        losses = torch.stack(self.log_train_loss).double()
        nlls = torch.stack(self.log_nll_loss).double()
        weights = torch.stack(self.log_train_weights).double()
        if self.config['grad_clamp']:
            keep = torch.isfinite(losses) & torch.isfinite(nlls)
        else:
            keep = torch.ones_like(losses, dtype=torch.bool)

        loss_sum = losses[keep].sum().item()
        nll_sum = nlls[keep].sum().item()
        counted_weights = weights[keep].sum().item()

        self.epoch_loss += loss_sum
        self.epoch_nll_loss += nll_sum
        self.epoch_weights += weights.sum().item()

        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        self.log_grad_norms = []
        return loss_sum, nll_sum, counted_weights

    def report_epoch(self, epoch, batches):
        """Log end-of-epoch statistics, record the epoch perplexities and reset epoch counters.

        :param epoch: 1-based epoch number (only used for logging)
        :param batches: number of batches run in this epoch
        """
        # Fold in the batches run since the last periodic log; otherwise the
        # tail of every epoch is missing from the epoch statistics.
        self._flush_log_window()

        words_per_sec = (self.epoch_weights / self.epoch_time
                         if self.epoch_time else 0.)
        sec_per_batch = self.epoch_time / batches if batches else 0.
        self.logger.info(f'Finished epoch {epoch}')
        self.logger.info(f' Took {ut.format_time(self.epoch_time)}')
        self.logger.info(f' avg words/sec {words_per_sec:.2f}')
        self.logger.info(f' avg sec/batch {sec_per_batch:.2f}')
        self.logger.info(f' {batches} batches')

        if self.epoch_weights:
            train_smooth_perp = self._to_perp(self.epoch_loss / self.epoch_weights)
            train_true_perp = self._to_perp(self.epoch_nll_loss / self.epoch_weights)
        else:
            train_smooth_perp = train_true_perp = float('inf')
        self.train_smooth_perps.append(float(train_smooth_perp))
        self.train_true_perps.append(float(train_true_perp))

        # This epoch's length is the estimate for the next one (>= 1, see __init__).
        self.est_batches = max(1, batches)
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        self.logger.info(
            f' smooth, true perp: {float(train_smooth_perp):.2f}, {float(train_true_perp):.2f}'
        )

    def clip_grad_values(self):
        """Clamp every gradient into [-grad_clamp, grad_clamp] and replace NaNs with 0.0.

        Adapted from https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_value_
        This is the same as torch.nn.utils.clip_grad_value_, except it also
        sets NaN gradients to 0.0 (clamp_ leaves NaN untouched).
        """
        clip_value = float(self.config['grad_clamp'])
        for p in self.model.parameters():
            if p.grad is None:
                continue
            grad = p.grad.data
            grad.clamp_(min=-clip_value, max=clip_value)
            grad[torch.isnan(grad)] = 0.0

    def get_params(self, pe=False):
        """Yield the parameters whose name is (pe=True) / is not (pe=False) in model.struct_params."""
        for n, p in self.model.named_parameters():
            if (n in self.model.struct_params) == pe:
                yield p

    def run_log(self, batch, epoch, batch_data):
        """Run one optimisation step on ``batch_data`` and log stats every ``log_freq`` batches.

        :param batch: 1-based batch index within the current epoch
        :param epoch: 1-based epoch number
        :param batch_data: tuple whose last four items are
                           (src_toks, src_structs, trg_toks, targets)
        """
        #with torch.autograd.detect_anomaly(): # throws exception when any forward computation produces nan
        start = time.time()
        _, src_toks, src_structs, trg_toks, targets = batch_data

        # zero grad
        self.optimizer.zero_grad()

        # get loss
        ret = self.model(src_toks, src_structs, trg_toks, targets, batch, epoch)
        loss = ret['loss']
        nll_loss = ret['nll_loss']

        if self.config['normalize_loss'] == ac.LOSS_TOK:
            opt_loss = loss / (targets != ac.PAD_ID).sum()
        elif self.config['normalize_loss'] == ac.LOSS_BATCH:
            opt_loss = loss / targets.size()[0]
        else:
            opt_loss = loss

        opt_loss.backward()

        # clip gradient
        if self.config['grad_clamp']:
            self.clip_grad_values()
        if self.config['grad_clip_pe']:
            # PE params get their own norm budget; the global clip below then
            # covers only the remaining parameters.
            pe_params = list(self.get_params(True))
            if pe_params:
                torch.nn.utils.clip_grad_norm_(pe_params,
                                               self.config['grad_clip_pe'])
            norm_params = self.get_params()
        else:
            norm_params = self.model.parameters()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            norm_params, self.config['grad_clip']).detach()

        # update
        self.adjust_lr()
        self.optimizer.step()

        # update training stats (kept as tensors; converted only at log time)
        num_words = (targets != ac.PAD_ID).detach().sum()
        self.total_batches += 1
        self.log_train_loss.append(loss.detach())
        self.log_nll_loss.append(nll_loss.detach())
        self.log_train_weights.append(num_words)
        self.log_grad_norms.append(grad_norm)
        self.epoch_time += time.time() - start

        if self.total_batches % self.log_freq == 0:
            avg_grad_norm = float(
                sum(self.log_grad_norms) / len(self.log_grad_norms))
            loss_sum, nll_sum, counted_weights = self._flush_log_window()
            if counted_weights:
                avg_smooth_perp = self._to_perp(loss_sum / counted_weights)
                avg_true_perp = self._to_perp(nll_sum / counted_weights)
            else:
                avg_smooth_perp = avg_true_perp = float('inf')

            acc_speed_word = (self.epoch_weights / self.epoch_time
                              if self.epoch_time else 0.)
            acc_speed_time = self.epoch_time / batch
            est_percent = int(100 * batch / self.est_batches)
            epoch_len = max(5, ut.get_num_digits(self.config['max_epochs']))
            batch_len = max(5, ut.get_num_digits(self.est_batches))
            if batch > self.est_batches:
                remaining = '?'
            else:
                remaining = ut.format_time(acc_speed_time *
                                           (self.est_batches - batch))

            cells = [
                f'{epoch:{epoch_len}}', f'{batch:{batch_len}}',
                f'{est_percent:3}%', f'{remaining:>9}',
                f'{acc_speed_word:#10.4g}', f'{acc_speed_time:#6.4g}s',
                f'{avg_smooth_perp:#11.4g}', f'{avg_true_perp:#9.4g}',
                f'{avg_grad_norm:#9.4g}'
            ]
            self.logger.info(' '.join(cells))

    def adjust_lr(self):
        """Set the optimizer learning rate for the next step according to ``warmup_style``.

        ORG_WARMUP: inverse-sqrt schedule scaled by embed_dim**-0.5, floored at min_lr.
        FIXED_WARMUP: linear start_lr -> lr, then lr * sqrt(warmup / step), floored at min_lr.
        UPFLAT_WARMUP: linear start_lr -> lr during warmup only; afterwards the
        rate is left alone (maybe_validate may anneal it).
        Any other style leaves the learning rate untouched.
        """
        warmup_style = self.config['warmup_style']
        step = self.total_batches + 1.0
        if warmup_style == ac.ORG_WARMUP:
            if step < self.config['warmup_steps']:
                lr = self.config['embed_dim']**(
                    -0.5) * step * self.config['warmup_steps']**(-1.5)
            else:
                lr = max(self.config['embed_dim']**(-0.5) * step**(-0.5),
                         self.config['min_lr'])
            for p in self.optimizer.param_groups:
                p['lr'] = lr
        elif warmup_style == ac.FIXED_WARMUP:
            warmup_steps = self.config['warmup_steps']
            start_lr = self.config['start_lr']
            peak_lr = self.config['lr']
            min_lr = self.config['min_lr']
            if step < warmup_steps:
                lr = start_lr + (peak_lr - start_lr) * step / warmup_steps
            else:
                lr = max(min_lr,
                         peak_lr * warmup_steps**(0.5) * step**(-0.5))
            for p in self.optimizer.param_groups:
                p['lr'] = lr
        elif warmup_style == ac.UPFLAT_WARMUP:
            warmup_steps = self.config['warmup_steps']
            start_lr = self.config['start_lr']
            peak_lr = self.config['lr']
            if step < warmup_steps:
                lr = start_lr + (peak_lr - start_lr) * step / warmup_steps
                for p in self.optimizer.param_groups:
                    p['lr'] = lr

    def train(self):
        """Run the full training loop, then save stats/model and translate test and dev."""
        self.model.train()
        stop_early = False
        early_stop_msg_num = self.config[
            'early_stop_patience'] * self.validate_freq
        # validate_freq counts epochs when validating per epoch, batches otherwise.
        early_stop_msg_metric = 'epochs' if self.config[
            'val_per_epoch'] else 'batches'
        early_stop_msg = f'No improvement for last {early_stop_msg_num} {early_stop_msg_metric}; stopping early!'
        for epoch in range(1, self.config['max_epochs'] + 1):
            batch = 0
            for batch_data in self.model.data_manager.get_batches(
                    mode=ac.TRAINING, num_preload=self.num_preload):
                if batch == 0:
                    self.logger.info(f'Begin epoch {epoch}')
                    epoch_str = ' ' * max(
                        0,
                        ut.get_num_digits(self.config['max_epochs']) -
                        5) + 'epoch'
                    batch_str = ' ' * max(
                        0,
                        ut.get_num_digits(self.est_batches) - 5) + 'batch'
                    self.logger.info(' '.join([
                        epoch_str, batch_str, 'est%', 'remaining',
                        'trg word/s', 's/batch', 'smooth perp', 'true perp',
                        'grad norm'
                    ]))
                batch += 1
                self.run_log(batch, epoch, batch_data)
                if not self.config['val_per_epoch']:
                    stop_early = self.maybe_validate()
                    if stop_early:
                        self.logger.info(early_stop_msg)
                        break
            if stop_early:
                break

            self.report_epoch(epoch, batch)
            if self.config['val_per_epoch'] and epoch % self.validate_freq == 0:
                stop_early = self.maybe_validate(just_validate=True)
                if stop_early:
                    self.logger.info(early_stop_msg)
                    break

        # NOTE(review): the final validation is gated on val_by_bleu rather
        # than val_per_epoch -- confirm this is intended.
        if not self.config['val_by_bleu'] and not stop_early:
            # validate 1 last time
            self.maybe_validate(just_validate=True)

        self.logger.info('Training finished')
        self.logger.info('Train smooth perps:')
        self.logger.info(', '.join(
            [f'{x:.2f}' for x in self.train_smooth_perps]))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join([f'{x:.2f}' for x in self.train_true_perps]))
        numpy.save(
            os.path.join(self.config['save_to'], 'train_smooth_perps.npy'),
            self.train_smooth_perps)
        numpy.save(
            os.path.join(self.config['save_to'], 'train_true_perps.npy'),
            self.train_true_perps)
        self.model.save()

        # Evaluate test
        data_manager = self.model.data_manager
        test_file = data_manager.data_files[ac.TESTING][data_manager.src_lang]
        dev_file = data_manager.data_files[ac.VALIDATING][data_manager.src_lang]
        if os.path.exists(test_file):
            self.logger.info('Evaluate test')
            self.restart_to_best_checkpoint()
            self.model.save()
            self.validator.translate(test_file, to_ids=True)
            self.logger.info('Translate dev set')
            self.validator.translate(dev_file, to_ids=True)

    def restart_to_best_checkpoint(self):
        """Load the checkpoint with the best dev BLEU (val_by_bleu) or lowest dev perplexity."""
        if self.config['val_by_bleu']:
            best_bleu = numpy.max(self.validator.best_bleus)
            best_cpkt_path = self.validator.get_cpkt_path(best_bleu)
        else:
            best_perp = numpy.min(self.validator.best_perps)
            best_cpkt_path = self.validator.get_cpkt_path(best_perp)

        self.logger.info(f'Restore best cpkt from {best_cpkt_path}')
        self.model.load_state_dict(torch.load(best_cpkt_path))

    def is_patience_exhausted(self, patience, if_worst=False):
        '''
        if_worst=False (default) -> check if last patience epochs have failed to improve dev score
        if_worst=True -> check if last epoch was WORSE than the patience epochs before it

        Returns a falsy value when patience is 0 or fewer than patience + 1
        scores have been recorded.
        '''
        by_bleu = bool(self.config['val_by_bleu'])
        curve = self.validator.bleu_curve if by_bleu else self.validator.perp_curve
        # BLEU: higher is better; perplexity: lower is better. if_worst looks
        # for the opposite extreme. Compare as bools: `is not` mis-handles
        # truthy non-bool config values such as 1.
        best_worse = max if by_bleu != bool(if_worst) else min
        return patience and len(curve) > patience and curve[
            -1 if if_worst else -1 - patience] == best_worse(
                curve[-1 - patience:])

    def maybe_validate(self, just_validate=False):
        """Validate (and possibly anneal the LR) every validate_freq batches or when forced.

        :param just_validate: validate regardless of the batch counter
        :return: truthy when early-stopping patience is exhausted, else falsy
        """
        if self.total_batches % self.validate_freq == 0 or just_validate:
            self.model.save()
            self.validator.validate_and_save()

            # if doing annealing
            step = self.total_batches + 1.0
            warmup_steps = self.config['warmup_steps']
            warmup_style = self.config['warmup_style']
            # Parenthesised explicitly: `and` binds tighter than `or`, which
            # previously let NO_WARMUP anneal even with lr_decay == 0.
            annealing = (warmup_style == ac.NO_WARMUP
                         or (warmup_style == ac.UPFLAT_WARMUP
                             and step >= warmup_steps))
            if annealing and self.config['lr_decay'] > 0:
                decay_patience = self.config['lr_decay_patience']
                if self.is_patience_exhausted(decay_patience, if_worst=True):
                    if self.config['val_by_bleu']:
                        metric = 'bleu'
                        scores = self.validator.bleu_curve
                    else:
                        metric = 'perp'
                        scores = self.validator.perp_curve
                    scores = ', '.join(
                        [str(x) for x in scores[-1 - decay_patience:]])
                    self.logger.info(f'Past {metric} scores are {scores}')
                    # when don't use warmup, decay lr if dev not improve
                    if self.lr * self.config['lr_decay'] >= self.config[
                            'min_lr']:
                        new_lr = self.lr * self.config['lr_decay']
                        self.logger.info(
                            f'Anneal the learning rate from {self.lr} to {new_lr}'
                        )
                        self.lr = new_lr
                        for p in self.optimizer.param_groups:
                            p['lr'] = self.lr

            return self.is_patience_exhausted(
                self.config['early_stop_patience'])
def __init__(self, model: Model, config: dict) -> None:
    """
    Build a TrainManager for ``model`` from the ``config`` dictionary.

    Validation of config values happens in the same order as the objects
    they configure are constructed, so a bad value raises
    ConfigurationError before anything that depends on it is built.

    :param model: torch module defining the model
    :param config: dictionary containing the training configurations
    :raises ConfigurationError: on an invalid normalization, eval_metric,
        early_stopping_metric or segmentation level
    """
    train_cfg = config["training"]

    # -- output directory, logging and tensorboard --
    self.model_dir = make_model_dir(
        train_cfg["model_dir"],
        overwrite=train_cfg.get("overwrite", False))
    self.logger = make_logger(model_dir=self.model_dir)
    self.logging_freq = train_cfg.get("logging_freq", 100)
    self.valid_report_file = "{}/validations.txt".format(self.model_dir)
    self.tb_writer = SummaryWriter(log_dir=self.model_dir + "/tensorboard/")

    # -- model --
    self.model = model
    self.pad_index = model.pad_index
    self.bos_index = model.bos_index
    self._log_parameters_list()

    # -- objective --
    self.label_smoothing = train_cfg.get("label_smoothing", 0.0)
    self.loss = XentLoss(pad_index=self.pad_index,
                         smoothing=self.label_smoothing)
    self.normalization = train_cfg.get("normalization", "batch")
    if self.normalization not in ("batch", "tokens"):
        raise ConfigurationError("Invalid normalization. "
                                 "Valid options: 'batch', 'tokens'.")

    # -- optimization --
    self.learning_rate_min = train_cfg.get("learning_rate_min", 1.0e-8)
    self.clip_grad_fun = build_gradient_clipper(config=train_cfg)
    self.optimizer = build_optimizer(config=train_cfg,
                                     parameters=model.parameters())

    # -- validation & early stopping --
    self.validation_freq = train_cfg.get("validation_freq", 1000)
    self.log_valid_sents = train_cfg.get("print_valid_sents", [0, 1, 2])
    self.ckpt_queue = queue.Queue(
        maxsize=train_cfg.get("keep_last_ckpts", 5))
    self.eval_metric = train_cfg.get("eval_metric", "bleu")
    if self.eval_metric not in ("bleu", "chrf"):
        raise ConfigurationError("Invalid setting for 'eval_metric', "
                                 "valid options: 'bleu', 'chrf'.")
    self.early_stopping_metric = train_cfg.get("early_stopping_metric",
                                               "eval_metric")

    # Direction of the early-stopping metric: loss/ppl are minimized,
    # BLEU/chrF are maximized. Checkpoints are written whenever this metric
    # reaches a new best value.
    stop_metric = self.early_stopping_metric
    if stop_metric in ("ppl", "loss"):
        self.minimize_metric = True
    elif stop_metric == "eval_metric":
        # any eval metric other than BLEU/chrF would be minimized (not yet implemented)
        self.minimize_metric = self.eval_metric not in ("bleu", "chrf")
    else:
        raise ConfigurationError(
            "Invalid setting for 'early_stopping_metric', "
            "valid options: 'loss', 'ppl', 'eval_metric'.")

    # -- learning rate scheduling --
    self.scheduler, self.scheduler_step_at = build_scheduler(
        config=train_cfg,
        scheduler_mode="min" if self.minimize_metric else "max",
        optimizer=self.optimizer,
        hidden_size=config["model"]["encoder"]["hidden_size"])

    # -- data & batch handling --
    self.level = config["data"]["level"]
    if self.level not in ("word", "bpe", "char"):
        raise ConfigurationError("Invalid segmentation level. "
                                 "Valid options: 'word', 'bpe', 'char'.")
    self.shuffle = train_cfg.get("shuffle", True)
    self.epochs = train_cfg["epochs"]
    self.batch_size = train_cfg["batch_size"]
    self.batch_type = train_cfg.get("batch_type", "sentence")
    self.eval_batch_size = train_cfg.get("eval_batch_size", self.batch_size)
    self.eval_batch_type = train_cfg.get("eval_batch_type", self.batch_type)
    self.batch_multiplier = train_cfg.get("batch_multiplier", 1)

    # -- generation --
    self.max_output_length = train_cfg.get("max_output_length", None)

    # -- CPU / GPU --
    self.use_cuda = train_cfg["use_cuda"]
    if self.use_cuda:
        self.model.cuda()
        self.loss.cuda()

    # -- training statistics --
    self.steps = 0
    # set to True to stop training once the learning-rate minimum is reached
    self.stop = False
    self.total_tokens = 0
    self.best_ckpt_iteration = 0
    # start from the worst possible score for the chosen direction
    self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf

    def _is_best(score):
        # reads minimize_metric / best_ckpt_score at call time
        if self.minimize_metric:
            return score < self.best_ckpt_score
        return score > self.best_ckpt_score

    self.is_best = _is_best

    # -- optional warm start from a checkpoint --
    if "load_model" in train_cfg:
        model_load_path = train_cfg["load_model"]
        self.logger.info("Loading model from %s", model_load_path)
        self.init_from_checkpoint(model_load_path)
class Trainer(object):
    """Trainer

    Builds the data manager, validator, model and Adam optimizer from the
    configuration named by ``args.proto`` and runs training with periodic
    logging, optional LR warmup/annealing, validation and final translation.
    """

    def __init__(self, args):
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        self.normalize_loss = self.config['normalize_loss']
        self.patience = self.config['patience']
        self.lr = self.config['lr']
        self.lr_decay = self.config['lr_decay']
        self.max_epochs = self.config['max_epochs']
        self.warmup_steps = self.config['warmup_steps']

        self.train_smooth_perps = []
        self.train_true_perps = []

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)

        self.val_per_epoch = self.config['val_per_epoch']
        self.validate_freq = int(self.config['validate_freq'])
        self.logger.info('Evaluate every {} {}'.format(
            self.validate_freq, 'epochs' if self.val_per_epoch else 'batches'))

        # For logging
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_nll_loss = 0.
        self.log_train_weights = 0.

        self.num_batches_done = 0  # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_nll_loss = 0.  # total train nll loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch

        # get model
        self.model = Model(self.config).to(self.device)

        param_count = sum(
            [numpy.prod(p.size()) for p in self.model.parameters()])
        self.logger.info('Model has {:,} parameters'.format(int(param_count)))

        # get optimizer
        beta1 = self.config['beta1']
        beta2 = self.config['beta2']
        epsilon = self.config['epsilon']
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          betas=(beta1, beta2),
                                          eps=epsilon)

    def report_epoch(self, e):
        """Log end-of-epoch statistics, record train perplexities and reset epoch counters.

        :param e: 1-based epoch number (only used for logging)
        """
        # Guard the divisions so an empty epoch doesn't raise ZeroDivisionError.
        words_per_sec = (self.epoch_weights / self.epoch_time
                         if self.epoch_time else 0.)
        sec_per_batch = (self.epoch_time / self.epoch_batches_done
                         if self.epoch_batches_done else 0.)
        self.logger.info('Finish epoch {}'.format(e))
        self.logger.info(' It takes {}'.format(
            ut.format_seconds(self.epoch_time)))
        self.logger.info(' Avergage # words/second {}'.format(words_per_sec))
        self.logger.info(' Average seconds/batch {}'.format(sec_per_batch))

        if self.epoch_weights:
            train_smooth_perp = self.epoch_loss / self.epoch_weights
            train_true_perp = self.epoch_nll_loss / self.epoch_weights
        else:
            train_smooth_perp = train_true_perp = float('inf')

        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        # exp() of the average loss, capped to avoid overflow
        train_smooth_perp = numpy.exp(
            train_smooth_perp) if train_smooth_perp < 300 else float('inf')
        self.train_smooth_perps.append(train_smooth_perp)
        train_true_perp = numpy.exp(
            train_true_perp) if train_true_perp < 300 else float('inf')
        self.train_true_perps.append(train_true_perp)

        self.logger.info(
            ' smoothed train perplexity: {}'.format(train_smooth_perp))
        self.logger.info(
            ' true train perplexity: {}'.format(train_true_perp))

    def run_log(self, b, e, batch_data):
        """Run one optimisation step and log averaged stats every ``log_freq`` batches.

        :param b: 1-based batch index within the epoch
        :param e: 0-based epoch index
        :param batch_data: (src_toks, trg_toks, targets) tensors
        """
        start = time.time()
        src_toks, trg_toks, targets = batch_data
        src_toks_cuda = src_toks.to(self.device)
        trg_toks_cuda = trg_toks.to(self.device)
        targets_cuda = targets.to(self.device)

        # zero grad
        self.optimizer.zero_grad()

        # get loss
        ret = self.model(src_toks_cuda, trg_toks_cuda, targets_cuda)
        loss = ret['loss']
        nll_loss = ret['nll_loss']

        if self.normalize_loss == ac.LOSS_TOK:
            opt_loss = loss / (targets_cuda != ac.PAD_ID).type(
                loss.type()).sum()
        elif self.normalize_loss == ac.LOSS_BATCH:
            # size(0) is a plain int; calling .type() on it raised AttributeError.
            opt_loss = loss / targets_cuda.size(0)
        else:
            opt_loss = loss

        opt_loss.backward()
        # clip gradient
        global_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                     self.config['grad_clip'])

        # update
        self.adjust_lr()
        self.optimizer.step()

        # update training stats (targets is the CPU copy, so .numpy() is safe)
        num_words = (targets != ac.PAD_ID).detach().numpy().sum()

        loss = loss.cpu().detach().numpy()
        nll_loss = nll_loss.cpu().detach().numpy()
        self.num_batches_done += 1
        self.log_train_loss += loss
        self.log_nll_loss += nll_loss
        self.log_train_weights += num_words

        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_nll_loss += nll_loss
        self.epoch_weights += num_words
        self.epoch_time += time.time() - start

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_smooth_perp = self.log_train_loss / self.log_train_weights
            avg_smooth_perp = numpy.exp(
                avg_smooth_perp) if avg_smooth_perp < 300 else float('inf')
            avg_true_perp = self.log_nll_loss / self.log_train_weights
            avg_true_perp = numpy.exp(
                avg_true_perp) if avg_true_perp < 300 else float('inf')

            self.log_train_loss = 0.
            self.log_nll_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(
                b, e + 1, self.max_epochs))
            self.logger.info(
                ' avg smooth perp: {0:.2f}'.format(avg_smooth_perp))
            self.logger.info(
                ' avg true perp: {0:.2f}'.format(avg_true_perp))
            self.logger.info(' acc trg words/s: {}'.format(
                int(acc_speed_word)))
            self.logger.info(
                ' acc sec/batch: {0:.2f}'.format(acc_speed_time))
            self.logger.info(' global norm: {0:.2f}'.format(global_norm))

    def adjust_lr(self):
        """With ORG_WARMUP, set the inverse-sqrt ("Noam") learning rate for the next step."""
        if self.config['warmup_style'] == ac.ORG_WARMUP:
            step = self.num_batches_done + 1.0
            if step < self.config['warmup_steps']:
                lr = self.config['embed_dim']**(
                    -0.5) * step * self.config['warmup_steps']**(-1.5)
            else:
                lr = max(self.config['embed_dim']**(-0.5) * step**(-0.5),
                         self.config['min_lr'])
            for p in self.optimizer.param_groups:
                p['lr'] = lr

    def train(self):
        """Train for max_epochs, then save stats/final model and translate test and dev per checkpoint."""
        self.model.train()
        train_ids_file = self.data_manager.data_files['ids']
        for e in range(self.max_epochs):
            b = 0
            for batch_data in self.data_manager.get_batch(
                    ids_file=train_ids_file,
                    shuffle=True,
                    num_preload=self.num_preload):
                b += 1
                self.run_log(b, e, batch_data)
                if not self.val_per_epoch:
                    self.maybe_validate()

            self.report_epoch(e + 1)
            if self.val_per_epoch and (e + 1) % self.validate_freq == 0:
                self.maybe_validate(just_validate=True)

        # validate 1 last time
        if not self.config['val_per_epoch']:
            self.maybe_validate(just_validate=True)

        self.logger.info('It is finally done, mate!')
        self.logger.info('Train smoothed perps:')
        self.logger.info(', '.join(map(str, self.train_smooth_perps)))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join(map(str, self.train_true_perps)))
        numpy.save(join(self.config['save_to'], 'train_smooth_perps.npy'),
                   self.train_smooth_perps)
        numpy.save(join(self.config['save_to'], 'train_true_perps.npy'),
                   self.train_true_perps)

        self.logger.info('Save final checkpoint')
        self.save_checkpoint()

        # Evaluate on test
        for checkpoint in self.data_manager.checkpoints:
            self.logger.info('Translate for {}'.format(checkpoint))
            dev_file = self.data_manager.dev_files[checkpoint][
                self.data_manager.src_lang]
            test_file = self.data_manager.test_files[checkpoint][
                self.data_manager.src_lang]
            if exists(test_file):
                self.logger.info(' Evaluate on test')
                self.restart_to_best_checkpoint(checkpoint)
                self.validator.translate(self.model, test_file)
                self.logger.info(' Also translate dev')
                self.validator.translate(self.model, dev_file)

    def save_checkpoint(self):
        """Save the model state_dict to <save_to>/<model_name>.pth."""
        cpkt_path = join(self.config['save_to'],
                         '{}.pth'.format(self.config['model_name']))
        torch.save(self.model.state_dict(), cpkt_path)

    def restart_to_best_checkpoint(self, checkpoint):
        """Load the checkpoint with the lowest dev perplexity for ``checkpoint``."""
        best_perp = numpy.min(self.validator.best_perps[checkpoint])
        best_cpkt_path = self.validator.get_cpkt_path(checkpoint, best_perp)

        self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
        self.model.load_state_dict(torch.load(best_cpkt_path))

    def _dev_perp_got_worse(self):
        """Return True if the latest dev perplexity exceeds each of the previous ``patience`` ones.

        Returns False when patience <= 0 (previously max() of an empty slice
        raised ValueError) or when fewer than patience + 1 scores exist.
        """
        curve = self.validator.perp_curve
        if self.patience <= 0 or len(curve) <= self.patience:
            return False
        return bool(curve[-1] > max(curve[-1 - self.patience:-1]))

    def maybe_validate(self, just_validate=False):
        """Validate every validate_freq batches (or when forced); anneal LR on dev regressions.

        :param just_validate: validate regardless of the batch counter
        """
        if self.num_batches_done % self.validate_freq == 0 or just_validate:
            self.save_checkpoint()
            self.validator.validate_and_save(self.model)

            # if doing annealing
            if self.config[
                    'warmup_style'] == ac.NO_WARMUP and self.lr_decay > 0:
                if self._dev_perp_got_worse():
                    metric = 'perp'
                    scores = ', '.join(
                        map(str,
                            list(self.validator.perp_curve[-1 -
                                                           self.patience:])))
                    self.logger.info('Past {} are {}'.format(metric, scores))
                    # when don't use warmup, decay lr if dev not improve
                    if self.lr * self.lr_decay >= self.config['min_lr']:
                        self.logger.info(
                            'Anneal the learning rate from {} to {}'.format(
                                self.lr, self.lr * self.lr_decay))
                        self.lr = self.lr * self.lr_decay
                        for p in self.optimizer.param_groups:
                            p['lr'] = self.lr