def __init__(self, args):
    super(Extractor, self).__init__()
    config = getattr(configurations, args.proto)()
    self.logger = ut.get_logger(config['log_file'])

    self.model_file = args.model_file
    var_list = args.var_list
    save_to = args.save_to

    if var_list is None:
        raise ValueError('Empty var list')
    if self.model_file is None or not os.path.exists(self.model_file):
        raise ValueError('Input file or model file does not exist')
    if not os.path.exists(save_to):
        os.makedirs(save_to)

    self.logger.info('Extracting these vars: {}'.format(', '.join(var_list)))

    model = Model(config)
    model.load_state_dict(torch.load(self.model_file))
    var_values = operator.attrgetter(*var_list)(model)
    if len(var_list) == 1:
        var_values = [var_values]

    for var, var_value in zip(var_list, var_values):
        var_path = os.path.join(save_to, var + '.npy')
        # detach in case the extracted attribute is a Parameter that requires grad
        numpy.save(var_path, var_value.detach().cpu().numpy())
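# A minimal standalone sketch of the extraction step above: operator.attrgetter
# accepts dotted paths, so each entry in var_list can reach into submodules.
# The toy module and variable names below are illustrative only, not part of
# this project.
import operator
import os

import numpy
import torch


class ToyModel(torch.nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.encoder = torch.nn.Linear(4, 4)
        self.scale = torch.nn.Parameter(torch.ones(4))


def extract_vars(model, var_list, save_to):
    os.makedirs(save_to, exist_ok=True)
    # attrgetter returns a single value for one path, a tuple for several
    var_values = operator.attrgetter(*var_list)(model)
    if len(var_list) == 1:
        var_values = [var_values]
    for var, var_value in zip(var_list, var_values):
        numpy.save(os.path.join(save_to, var + '.npy'),
                   var_value.detach().cpu().numpy())

# extract_vars(ToyModel(), ['encoder.weight', 'scale'], '/tmp/extracted')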
def translate(self):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = Model(self.config).to(device)
    self.logger.info('Restore model from {}'.format(self.model_file))
    model.load_state_dict(torch.load(self.model_file))
    model.eval()

    best_trans_file = self.input_file + '.best_trans'
    beam_trans_file = self.input_file + '.beam_trans'
    open(best_trans_file, 'w').close()
    open(beam_trans_file, 'w').close()

    num_sents = 0
    with open(self.input_file, 'r') as f:
        for line in f:
            if line.strip():
                num_sents += 1
    all_best_trans = [''] * num_sents
    all_beam_trans = [''] * num_sents

    with torch.no_grad():
        self.logger.info('Start translating {}'.format(self.input_file))
        start = time.time()
        count = 0
        for (src_toks, original_idxs) in self.data_manager.get_trans_input(self.input_file):
            src_toks_cuda = src_toks.to(device)
            rets = model.beam_decode(src_toks_cuda)

            for i, ret in enumerate(rets):
                probs = ret['probs'].cpu().detach().numpy().reshape([-1])
                scores = ret['scores'].cpu().detach().numpy().reshape([-1])
                symbols = ret['symbols'].cpu().detach().numpy()

                best_trans, best_trans_ids, beam_trans = self.get_trans(probs, scores, symbols)
                all_best_trans[original_idxs[i]] = best_trans + '\n'
                all_beam_trans[original_idxs[i]] = beam_trans + '\n\n'

                count += 1
                if count % 100 == 0:
                    self.logger.info('  Translating line {}, average {} seconds/sent'.format(
                        count, (time.time() - start) / count))

    model.train()

    with open(best_trans_file, 'w') as ftrans, open(beam_trans_file, 'w') as btrans:
        ftrans.write(''.join(all_best_trans))
        btrans.write(''.join(all_beam_trans))

    self.logger.info('Done translating {}, it takes {} minutes'.format(
        self.input_file, float(time.time() - start) / 60.0))
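# Small sketch of the order-restoring bookkeeping used in translate() above:
# batches may come back in a different order than the input file (e.g. sorted
# by length for efficiency), so each hypothesis is written into a slot indexed
# by its original line number. The length-sort below is only an illustrative
# stand-in for whatever the data manager actually does.

def batches_with_original_idxs(sentences, batch_size=2):
    # pair each sentence with its original index, then sort by length
    indexed = sorted(enumerate(sentences), key=lambda x: len(x[1]))
    for i in range(0, len(indexed), batch_size):
        chunk = indexed[i:i + batch_size]
        original_idxs = [idx for idx, _ in chunk]
        batch = [sent for _, sent in chunk]
        yield batch, original_idxs


sents = ['a rather long source sentence', 'short one', 'mid length here']
all_best_trans = [''] * len(sents)
for batch, original_idxs in batches_with_original_idxs(sents):
    for i, hyp in enumerate(batch):  # pretend `hyp` is the decoded output
        all_best_trans[original_idxs[i]] = hyp + '\n'
# ''.join(all_best_trans) is now in the original file order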
class Trainer(object):
    """Trainer"""

    def __init__(self, args):
        super(Trainer, self).__init__()
        self.config = configurations.get_config(args.proto,
                                                getattr(configurations, args.proto),
                                                args.config_overrides)
        self.num_preload = args.num_preload
        self.lr = self.config['lr']

        ut.remove_files_in_dir(self.config['save_to'])
        self.logger = ut.get_logger(self.config['log_file'])

        self.train_smooth_perps = []
        self.train_true_perps = []

        # For logging
        self.log_freq = self.config['log_freq']  # log train stats every this many batches
        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        self.log_grad_norms = []
        self.total_batches = 0    # number of batches done for the whole training
        self.epoch_loss = 0.      # total train loss for the whole epoch
        self.epoch_nll_loss = 0.  # total train NLL loss for the whole epoch
        self.epoch_weights = 0.   # total train weights (# target words) for the whole epoch
        self.epoch_time = 0.      # total exec time for the whole epoch

        # get model
        device = ut.get_device()
        self.model = Model(self.config).to(device)

        self.validator = Validator(self.config, self.model)
        self.validate_freq = self.config['validate_freq']
        if self.validate_freq == 1:
            self.logger.info('Evaluate every '
                             + ('epoch' if self.config['val_per_epoch'] else 'batch'))
        else:
            self.logger.info(f'Evaluate every {self.validate_freq:,} '
                             + ('epochs' if self.config['val_per_epoch'] else 'batches'))

        # Estimated number of batches per epoch
        self.est_batches = max(self.model.data_manager.training_tok_counts) // self.config['batch_size']
        self.logger.info(f'Guessing around {self.est_batches:,} batches per epoch')

        param_count = sum([numpy.prod(p.size()) for p in self.model.parameters()])
        self.logger.info(f'Model has {int(param_count):,} parameters')

        # Set up parameter-specific optimizer options
        params = []
        for p in self.model.parameters():
            ptr = p.data_ptr()
            d = {'params': [p]}
            if ptr in self.model.parameter_attrs:
                attrs = self.model.parameter_attrs[ptr]
                for k in attrs:
                    d[k] = attrs[k]
            params.append(d)
        self.optimizer = torch.optim.Adam(params,
                                          lr=self.lr,
                                          betas=(self.config['beta1'], self.config['beta2']),
                                          eps=self.config['epsilon'])

    def report_epoch(self, epoch, batches):
        self.logger.info(f'Finished epoch {epoch}')
        self.logger.info(f'    Took {ut.format_time(self.epoch_time)}')
        self.logger.info(f'    avg words/sec {self.epoch_weights / self.epoch_time:.2f}')
        self.logger.info(f'    avg sec/batch {self.epoch_time / batches:.2f}')
        self.logger.info(f'    {batches} batches')

        if self.epoch_weights:
            train_smooth_perp = self.epoch_loss / self.epoch_weights
            train_true_perp = self.epoch_nll_loss / self.epoch_weights
        else:
            train_smooth_perp = float('inf')
            train_true_perp = float('inf')

        self.est_batches = batches
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.
        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        self.log_grad_norms = []

        train_smooth_perp = numpy.exp(train_smooth_perp) if train_smooth_perp < 300 else float('inf')
        self.train_smooth_perps.append(train_smooth_perp)
        train_true_perp = numpy.exp(train_true_perp) if train_true_perp < 300 else float('inf')
        self.train_true_perps.append(train_true_perp)

        self.logger.info(f'    smooth, true perp: {float(train_smooth_perp):.2f}, {float(train_true_perp):.2f}')

    def clip_grad_values(self):
        """
        Adapted from https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_value_
        Same as torch.nn.utils.clip_grad_value_, except it also sets NaN gradients to 0.0
        """
        parameters = self.model.parameters()
        clip_value = float(self.config['grad_clamp'])
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        for p in filter(lambda p: p.grad is not None, parameters):
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
            p.grad.data[torch.isnan(p.grad.data)] = 0.0

    def get_params(self, pe=False):
        for n, p in self.model.named_parameters():
            if (n in self.model.struct_params) == pe:
                yield p

    def run_log(self, batch, epoch, batch_data):
        # with torch.autograd.detect_anomaly():  # throws an exception when any forward computation produces nan
        start = time.time()
        _, src_toks, src_structs, trg_toks, targets = batch_data

        # zero grad
        self.optimizer.zero_grad()

        # get loss
        ret = self.model(src_toks, src_structs, trg_toks, targets, batch, epoch)
        loss = ret['loss']
        nll_loss = ret['nll_loss']

        if self.config['normalize_loss'] == ac.LOSS_TOK:
            opt_loss = loss / (targets != ac.PAD_ID).sum()
        elif self.config['normalize_loss'] == ac.LOSS_BATCH:
            opt_loss = loss / targets.size()[0]
        else:
            opt_loss = loss

        opt_loss.backward()

        # clip gradient
        if self.config['grad_clamp']:
            self.clip_grad_values()
        if self.config['grad_clip_pe']:
            pms = list(self.get_params(True))
            if pms:
                torch.nn.utils.clip_grad_norm_(pms, self.config['grad_clip_pe'])
            pms = self.get_params()
        else:
            pms = self.model.parameters()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.config['grad_clip']).detach()

        # update
        self.adjust_lr()
        self.optimizer.step()

        # update training stats
        num_words = (targets != ac.PAD_ID).detach().sum()

        loss = loss.detach()
        nll_loss = nll_loss.detach()

        self.total_batches += 1
        self.log_train_loss.append(loss)
        self.log_nll_loss.append(nll_loss)
        self.log_train_weights.append(num_words)
        self.log_grad_norms.append(grad_norm)
        self.epoch_time += time.time() - start

        if self.total_batches % self.log_freq == 0:
            log_train_loss = torch.tensor(0.0)
            log_nll_loss = torch.tensor(0.0)
            log_train_weights = torch.tensor(0.0)
            log_all_weights = torch.tensor(0.0)
            for smooth, nll, weight in zip(self.log_train_loss, self.log_nll_loss, self.log_train_weights):
                if not self.config['grad_clamp'] or (torch.isfinite(smooth) and torch.isfinite(nll)):
                    log_train_loss += smooth
                    log_nll_loss += nll
                    log_train_weights += weight
                log_all_weights += weight
            # log_train_loss = sum(x for x in self.log_train_loss).item()
            # log_nll_loss = sum(x for x in self.log_nll_loss).item()
            # log_train_weights = sum(x for x in self.log_train_weights).item()

            avg_smooth_perp = log_train_loss / log_train_weights
            avg_smooth_perp = numpy.exp(avg_smooth_perp) if avg_smooth_perp < 300 else float('inf')
            avg_true_perp = log_nll_loss / log_train_weights
            avg_true_perp = numpy.exp(avg_true_perp) if avg_true_perp < 300 else float('inf')

            self.epoch_loss += log_train_loss
            self.epoch_nll_loss += log_nll_loss
            self.epoch_weights += log_all_weights

            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / batch

            avg_grad_norm = sum(self.log_grad_norms) / len(self.log_grad_norms)
            # median_grad_norm = sorted(self.log_grad_norms)[len(self.log_grad_norms) // 2]

            est_percent = int(100 * batch / self.est_batches)
            epoch_len = max(5, ut.get_num_digits(self.config['max_epochs']))
            batch_len = max(5, ut.get_num_digits(self.est_batches))
            if batch > self.est_batches:
                remaining = '?'
            else:
                remaining = ut.format_time(acc_speed_time * (self.est_batches - batch))

            self.log_train_loss = []
            self.log_nll_loss = []
            self.log_train_weights = []
            self.log_grad_norms = []

            cells = [
                f'{epoch:{epoch_len}}',
                f'{batch:{batch_len}}',
                f'{est_percent:3}%',
                f'{remaining:>9}',
                f'{acc_speed_word:#10.4g}',
                f'{acc_speed_time:#6.4g}s',
                f'{avg_smooth_perp:#11.4g}',
                f'{avg_true_perp:#9.4g}',
                f'{avg_grad_norm:#9.4g}'
            ]
            self.logger.info(' '.join(cells))

    def adjust_lr(self):
        if self.config['warmup_style'] == ac.ORG_WARMUP:
            step = self.total_batches + 1.0
            if step < self.config['warmup_steps']:
                lr = self.config['embed_dim'] ** (-0.5) * step * self.config['warmup_steps'] ** (-1.5)
            else:
                lr = max(self.config['embed_dim'] ** (-0.5) * step ** (-0.5), self.config['min_lr'])
            for p in self.optimizer.param_groups:
                p['lr'] = lr
        elif self.config['warmup_style'] == ac.FIXED_WARMUP:
            warmup_steps = self.config['warmup_steps']
            step = self.total_batches + 1.0
            start_lr = self.config['start_lr']
            peak_lr = self.config['lr']
            min_lr = self.config['min_lr']
            if step < warmup_steps:
                lr = start_lr + (peak_lr - start_lr) * step / warmup_steps
            else:
                lr = max(min_lr, peak_lr * warmup_steps ** 0.5 * step ** (-0.5))
            for p in self.optimizer.param_groups:
                p['lr'] = lr
        elif self.config['warmup_style'] == ac.UPFLAT_WARMUP:
            warmup_steps = self.config['warmup_steps']
            step = self.total_batches + 1.0
            start_lr = self.config['start_lr']
            peak_lr = self.config['lr']
            min_lr = self.config['min_lr']
            if step < warmup_steps:
                lr = start_lr + (peak_lr - start_lr) * step / warmup_steps
                for p in self.optimizer.param_groups:
                    p['lr'] = lr
        else:
            pass

    def train(self):
        self.model.train()
        stop_early = False
        early_stop_msg_num = self.config['early_stop_patience'] * self.validate_freq
        early_stop_msg_metric = 'epochs' if self.config['val_by_bleu'] else 'batches'
        early_stop_msg = f'No improvement for last {early_stop_msg_num} {early_stop_msg_metric}; stopping early!'

        for epoch in range(1, self.config['max_epochs'] + 1):
            batch = 0
            for batch_data in self.model.data_manager.get_batches(mode=ac.TRAINING,
                                                                  num_preload=self.num_preload):
                if batch == 0:
                    self.logger.info(f'Begin epoch {epoch}')
                    epoch_str = ' ' * max(0, ut.get_num_digits(self.config['max_epochs']) - 5) + 'epoch'
                    batch_str = ' ' * max(0, ut.get_num_digits(self.est_batches) - 5) + 'batch'
                    self.logger.info(' '.join([
                        epoch_str, batch_str, 'est%', 'remaining', 'trg word/s',
                        's/batch', 'smooth perp', 'true perp', 'grad norm'
                    ]))
                batch += 1
                self.run_log(batch, epoch, batch_data)
                if not self.config['val_per_epoch']:
                    stop_early = self.maybe_validate()
                    if stop_early:
                        self.logger.info(early_stop_msg)
                        break
            if stop_early:
                break
            self.report_epoch(epoch, batch)
            if self.config['val_per_epoch'] and epoch % self.validate_freq == 0:
                stop_early = self.maybe_validate(just_validate=True)
                if stop_early:
                    self.logger.info(early_stop_msg)
                    break

        if not self.config['val_by_bleu'] and not stop_early:
            # validate one last time
            self.maybe_validate(just_validate=True)

        self.logger.info('Training finished')
        self.logger.info('Train smooth perps:')
        self.logger.info(', '.join([f'{x:.2f}' for x in self.train_smooth_perps]))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join([f'{x:.2f}' for x in self.train_true_perps]))
        numpy.save(os.path.join(self.config['save_to'], 'train_smooth_perps.npy'),
                   self.train_smooth_perps)
        numpy.save(os.path.join(self.config['save_to'], 'train_true_perps.npy'),
                   self.train_true_perps)

        self.model.save()

        # Evaluate test
        test_file = self.model.data_manager.data_files[ac.TESTING][self.model.data_manager.src_lang]
        dev_file = self.model.data_manager.data_files[ac.VALIDATING][self.model.data_manager.src_lang]
        if os.path.exists(test_file):
            self.logger.info('Evaluate test')
            self.restart_to_best_checkpoint()
            self.model.save()
            self.validator.translate(test_file, to_ids=True)
            self.logger.info('Translate dev set')
            self.validator.translate(dev_file, to_ids=True)

    def restart_to_best_checkpoint(self):
        if self.config['val_by_bleu']:
            best_bleu = numpy.max(self.validator.best_bleus)
            best_cpkt_path = self.validator.get_cpkt_path(best_bleu)
        else:
            best_perp = numpy.min(self.validator.best_perps)
            best_cpkt_path = self.validator.get_cpkt_path(best_perp)

        self.logger.info(f'Restore best cpkt from {best_cpkt_path}')
        self.model.load_state_dict(torch.load(best_cpkt_path))

    def is_patience_exhausted(self, patience, if_worst=False):
        '''
        if_worst=False (default) -> check whether the last `patience` epochs have failed to improve the dev score
        if_worst=True -> check whether the last epoch was WORSE than the `patience` epochs before it
        '''
        curve = self.validator.bleu_curve if self.config['val_by_bleu'] else self.validator.perp_curve
        best_worse = max if self.config['val_by_bleu'] is not if_worst else min
        return (patience and len(curve) > patience
                and curve[-1 if if_worst else -1 - patience] == best_worse(curve[-1 - patience:]))

    def maybe_validate(self, just_validate=False):
        if self.total_batches % self.validate_freq == 0 or just_validate:
            self.model.save()
            self.validator.validate_and_save()

            # if doing annealing
            step = self.total_batches + 1.0
            warmup_steps = self.config['warmup_steps']
            if self.config['warmup_style'] == ac.NO_WARMUP \
                    or (self.config['warmup_style'] == ac.UPFLAT_WARMUP and step >= warmup_steps) \
                    and self.config['lr_decay'] > 0:
                if self.is_patience_exhausted(self.config['lr_decay_patience'], if_worst=True):
                    if self.config['val_by_bleu']:
                        metric = 'bleu'
                        scores = self.validator.bleu_curve
                    else:
                        metric = 'perp'
                        scores = self.validator.perp_curve
                    scores = ', '.join([str(x) for x in scores[-1 - self.config['lr_decay_patience']:]])
                    self.logger.info(f'Past {metric} scores are {scores}')
                    # when not using warmup, decay lr if the dev score does not improve
                    if self.lr * self.config['lr_decay'] >= self.config['min_lr']:
                        new_lr = self.lr * self.config['lr_decay']
                        self.logger.info(f'Anneal the learning rate from {self.lr} to {new_lr}')
                        self.lr = new_lr
                        for p in self.optimizer.param_groups:
                            p['lr'] = self.lr

            return self.is_patience_exhausted(self.config['early_stop_patience'])
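# Standalone view of the ORG_WARMUP branch of adjust_lr() above (the schedule
# from "Attention Is All You Need"): the learning rate ramps up linearly for
# warmup_steps batches, then decays as 1/sqrt(step), floored at min_lr. The
# constants below are just example values, not this project's defaults.

def org_warmup_lr(step, embed_dim=512, warmup_steps=4000, min_lr=1e-5):
    if step < warmup_steps:
        return embed_dim ** (-0.5) * step * warmup_steps ** (-1.5)
    return max(embed_dim ** (-0.5) * step ** (-0.5), min_lr)

# for s in [1, 1000, 4000, 16000, 64000]:
#     print(s, org_warmup_lr(s))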
class Trainer(object):
    """Trainer"""

    def __init__(self, args):
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.normalize_loss = self.config['normalize_loss']
        self.patience = self.config['patience']
        self.lr = self.config['lr']
        self.lr_decay = self.config['lr_decay']
        self.max_epochs = self.config['max_epochs']
        self.warmup_steps = self.config['warmup_steps']

        self.train_smooth_perps = []
        self.train_true_perps = []

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)
        self.val_per_epoch = self.config['val_per_epoch']
        self.validate_freq = int(self.config['validate_freq'])
        self.logger.info('Evaluate every {} {}'.format(
            self.validate_freq, 'epochs' if self.val_per_epoch else 'batches'))

        # For logging
        self.log_freq = 100  # log train stats every this many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_nll_loss = 0.
        self.log_train_weights = 0.
        self.num_batches_done = 0    # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.         # total train loss for the whole epoch
        self.epoch_nll_loss = 0.     # total train NLL loss for the whole epoch
        self.epoch_weights = 0.      # total train weights (# target words) for the whole epoch
        self.epoch_time = 0.         # total exec time for the whole epoch

        # get model
        self.model = Model(self.config).to(self.device)
        param_count = sum([numpy.prod(p.size()) for p in self.model.parameters()])
        self.logger.info('Model has {:,} parameters'.format(param_count))

        # get optimizer
        beta1 = self.config['beta1']
        beta2 = self.config['beta2']
        epsilon = self.config['epsilon']
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          betas=(beta1, beta2),
                                          eps=epsilon)

    def report_epoch(self, e):
        self.logger.info('Finished epoch {}'.format(e))
        self.logger.info('    It took {}'.format(ut.format_seconds(self.epoch_time)))
        self.logger.info('    Average words/second: {}'.format(self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch: {}'.format(self.epoch_time / self.epoch_batches_done))

        train_smooth_perp = self.epoch_loss / self.epoch_weights
        train_true_perp = self.epoch_nll_loss / self.epoch_weights

        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        train_smooth_perp = numpy.exp(train_smooth_perp) if train_smooth_perp < 300 else float('inf')
        self.train_smooth_perps.append(train_smooth_perp)
        train_true_perp = numpy.exp(train_true_perp) if train_true_perp < 300 else float('inf')
        self.train_true_perps.append(train_true_perp)

        self.logger.info('    smoothed train perplexity: {}'.format(train_smooth_perp))
        self.logger.info('    true train perplexity: {}'.format(train_true_perp))

    def run_log(self, b, e, batch_data):
        start = time.time()
        src_toks, trg_toks, targets = batch_data
        src_toks_cuda = src_toks.to(self.device)
        trg_toks_cuda = trg_toks.to(self.device)
        targets_cuda = targets.to(self.device)

        # zero grad
        self.optimizer.zero_grad()

        # get loss
        ret = self.model(src_toks_cuda, trg_toks_cuda, targets_cuda)
        loss = ret['loss']
        nll_loss = ret['nll_loss']

        if self.normalize_loss == ac.LOSS_TOK:
            opt_loss = loss / (targets_cuda != ac.PAD_ID).type(loss.type()).sum()
        elif self.normalize_loss == ac.LOSS_BATCH:
            opt_loss = loss / targets_cuda.size()[0]
        else:
            opt_loss = loss

        opt_loss.backward()

        # clip gradient
        global_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                     self.config['grad_clip'])

        # update
        self.adjust_lr()
        self.optimizer.step()

        # update training stats
        num_words = (targets != ac.PAD_ID).detach().numpy().sum()
        loss = loss.cpu().detach().numpy()
        nll_loss = nll_loss.cpu().detach().numpy()

        self.num_batches_done += 1
        self.log_train_loss += loss
        self.log_nll_loss += nll_loss
        self.log_train_weights += num_words

        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_nll_loss += nll_loss
        self.epoch_weights += num_words
        self.epoch_time += time.time() - start

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_smooth_perp = self.log_train_loss / self.log_train_weights
            avg_smooth_perp = numpy.exp(avg_smooth_perp) if avg_smooth_perp < 300 else float('inf')
            avg_true_perp = self.log_nll_loss / self.log_train_weights
            avg_true_perp = numpy.exp(avg_true_perp) if avg_true_perp < 300 else float('inf')

            self.log_train_loss = 0.
            self.log_nll_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(b, e + 1, self.max_epochs))
            self.logger.info('   avg smooth perp: {0:.2f}'.format(avg_smooth_perp))
            self.logger.info('   avg true perp: {0:.2f}'.format(avg_true_perp))
            self.logger.info('   acc trg words/s: {}'.format(int(acc_speed_word)))
            self.logger.info('   acc sec/batch: {0:.2f}'.format(acc_speed_time))
            self.logger.info('   global norm: {0:.2f}'.format(global_norm))

    def adjust_lr(self):
        if self.config['warmup_style'] == ac.ORG_WARMUP:
            step = self.num_batches_done + 1.0
            if step < self.config['warmup_steps']:
                lr = self.config['embed_dim'] ** (-0.5) * step * self.config['warmup_steps'] ** (-1.5)
            else:
                lr = max(self.config['embed_dim'] ** (-0.5) * step ** (-0.5), self.config['min_lr'])
            for p in self.optimizer.param_groups:
                p['lr'] = lr

    def train(self):
        self.model.train()
        train_ids_file = self.data_manager.data_files['ids']
        for e in range(self.max_epochs):
            b = 0
            for batch_data in self.data_manager.get_batch(ids_file=train_ids_file,
                                                          shuffle=True,
                                                          num_preload=self.num_preload):
                b += 1
                self.run_log(b, e, batch_data)
                if not self.val_per_epoch:
                    self.maybe_validate()

            self.report_epoch(e + 1)
            if self.val_per_epoch and (e + 1) % self.validate_freq == 0:
                self.maybe_validate(just_validate=True)

        # validate one last time
        if not self.config['val_per_epoch']:
            self.maybe_validate(just_validate=True)

        self.logger.info('It is finally done, mate!')
        self.logger.info('Train smoothed perps:')
        self.logger.info(', '.join(map(str, self.train_smooth_perps)))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join(map(str, self.train_true_perps)))
        numpy.save(join(self.config['save_to'], 'train_smooth_perps.npy'), self.train_smooth_perps)
        numpy.save(join(self.config['save_to'], 'train_true_perps.npy'), self.train_true_perps)

        self.logger.info('Save final checkpoint')
        self.save_checkpoint()

        # Evaluate on test
        for checkpoint in self.data_manager.checkpoints:
            self.logger.info('Translate for {}'.format(checkpoint))
            dev_file = self.data_manager.dev_files[checkpoint][self.data_manager.src_lang]
            test_file = self.data_manager.test_files[checkpoint][self.data_manager.src_lang]
            if exists(test_file):
                self.logger.info('  Evaluate on test')
                self.restart_to_best_checkpoint(checkpoint)
                self.validator.translate(self.model, test_file)
                self.logger.info('  Also translate dev')
                self.validator.translate(self.model, dev_file)

    def save_checkpoint(self):
        cpkt_path = join(self.config['save_to'], '{}.pth'.format(self.config['model_name']))
        torch.save(self.model.state_dict(), cpkt_path)

    def restart_to_best_checkpoint(self, checkpoint):
        best_perp = numpy.min(self.validator.best_perps[checkpoint])
        best_cpkt_path = self.validator.get_cpkt_path(checkpoint, best_perp)

        self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
        self.model.load_state_dict(torch.load(best_cpkt_path))

    def maybe_validate(self, just_validate=False):
        if self.num_batches_done % self.validate_freq == 0 or just_validate:
            self.save_checkpoint()
            self.validator.validate_and_save(self.model)

            # if doing annealing
            if self.config['warmup_style'] == ac.NO_WARMUP and self.lr_decay > 0:
                cond = (len(self.validator.perp_curve) > self.patience
                        and self.validator.perp_curve[-1] > max(self.validator.perp_curve[-1 - self.patience:-1]))
                if cond:
                    metric = 'perp'
                    scores = self.validator.perp_curve[-1 - self.patience:]
                    scores = ', '.join(map(str, list(scores)))
                    self.logger.info('Past {} scores are {}'.format(metric, scores))
                    # when not using warmup, decay lr if the dev score does not improve
                    if self.lr * self.lr_decay >= self.config['min_lr']:
                        self.logger.info('Anneal the learning rate from {} to {}'.format(
                            self.lr, self.lr * self.lr_decay))
                        self.lr = self.lr * self.lr_decay
                        for p in self.optimizer.param_groups:
                            p['lr'] = self.lr
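# Both Trainer variants report perplexity the same way: accumulate the summed
# (label-smoothed and true NLL) loss and the number of non-PAD target tokens,
# then take exp(loss / tokens), treating any exponent over 300 as inf to avoid
# overflow. A tiny numeric sketch with made-up numbers:
import numpy


def perplexity(total_loss, total_target_words):
    if total_target_words == 0:
        return float('inf')
    avg = total_loss / total_target_words
    return numpy.exp(avg) if avg < 300 else float('inf')

# e.g. 51,200 summed NLL over 12,800 target words -> perplexity ~ 54.6
# print(perplexity(51200.0, 12800.0))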