class Trainer(object):
    """Trainer"""

    def __init__(self, args):
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.normalize_loss = self.config['normalize_loss']
        self.patience = self.config['patience']
        self.lr = self.config['lr']
        self.lr_decay = self.config['lr_decay']
        self.max_epochs = self.config['max_epochs']
        self.warmup_steps = self.config['warmup_steps']

        self.train_smooth_perps = []
        self.train_true_perps = []

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)
        self.val_per_epoch = self.config['val_per_epoch']
        self.validate_freq = int(self.config['validate_freq'])
        self.logger.info('Evaluate every {} {}'.format(
            self.validate_freq, 'epochs' if self.val_per_epoch else 'batches'))

        # For logging
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_nll_loss = 0.  # total nll loss every log_freq batches
        self.log_train_weights = 0.
        self.num_batches_done = 0  # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_nll_loss = 0.  # total train nll loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch, sounds like that tabloid

        # get model
        self.model = Model(self.config).to(self.device)
        param_count = sum([numpy.prod(p.size()) for p in self.model.parameters()])
        self.logger.info('Model has {:,} parameters'.format(param_count))

        # get optimizer
        beta1 = self.config['beta1']
        beta2 = self.config['beta2']
        epsilon = self.config['epsilon']
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          betas=(beta1, beta2),
                                          eps=epsilon)

    def report_epoch(self, e):
        self.logger.info('Finish epoch {}'.format(e))
        self.logger.info('    It takes {}'.format(ut.format_seconds(self.epoch_time)))
        self.logger.info('    Average # words/second {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch {}'.format(
            self.epoch_time / self.epoch_batches_done))

        train_smooth_perp = self.epoch_loss / self.epoch_weights
        train_true_perp = self.epoch_nll_loss / self.epoch_weights

        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        train_smooth_perp = numpy.exp(train_smooth_perp) if train_smooth_perp < 300 else float('inf')
        self.train_smooth_perps.append(train_smooth_perp)
        train_true_perp = numpy.exp(train_true_perp) if train_true_perp < 300 else float('inf')
        self.train_true_perps.append(train_true_perp)

        self.logger.info('    smoothed train perplexity: {}'.format(train_smooth_perp))
        self.logger.info('    true train perplexity: {}'.format(train_true_perp))

    def run_log(self, b, e, batch_data):
        start = time.time()
        src_toks, trg_toks, targets = batch_data
        src_toks_cuda = src_toks.to(self.device)
        trg_toks_cuda = trg_toks.to(self.device)
        targets_cuda = targets.to(self.device)

        # zero grad
        self.optimizer.zero_grad()

        # get loss
        ret = self.model(src_toks_cuda, trg_toks_cuda, targets_cuda)
        loss = ret['loss']
        nll_loss = ret['nll_loss']

        if self.normalize_loss == ac.LOSS_TOK:
            # normalize by the number of non-pad target tokens
            opt_loss = loss / (targets_cuda != ac.PAD_ID).type(loss.type()).sum()
        elif self.normalize_loss == ac.LOSS_BATCH:
            # normalize by the number of sentences in the batch
            opt_loss = loss / targets_cuda.size(0)
        else:
            opt_loss = loss

        opt_loss.backward()
        # clip gradient
        global_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                     self.config['grad_clip'])
        # update
        self.adjust_lr()
        self.optimizer.step()

        # update training stats
        num_words = (targets != ac.PAD_ID).detach().numpy().sum()

        loss = loss.cpu().detach().numpy()
        nll_loss = nll_loss.cpu().detach().numpy()
        self.num_batches_done += 1
        self.log_train_loss += loss
        self.log_nll_loss += nll_loss
        self.log_train_weights += num_words

        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_nll_loss += nll_loss
        self.epoch_weights += num_words
        self.epoch_time += time.time() - start

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_smooth_perp = self.log_train_loss / self.log_train_weights
            avg_smooth_perp = numpy.exp(avg_smooth_perp) if avg_smooth_perp < 300 else float('inf')
            avg_true_perp = self.log_nll_loss / self.log_train_weights
            avg_true_perp = numpy.exp(avg_true_perp) if avg_true_perp < 300 else float('inf')

            self.log_train_loss = 0.
            self.log_nll_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(b, e + 1, self.max_epochs))
            self.logger.info('   avg smooth perp:   {0:.2f}'.format(avg_smooth_perp))
            self.logger.info('   avg true perp:     {0:.2f}'.format(avg_true_perp))
            self.logger.info('   acc trg words/s:   {}'.format(int(acc_speed_word)))
            self.logger.info('   acc sec/batch:     {0:.2f}'.format(acc_speed_time))
            self.logger.info('   global norm:       {0:.2f}'.format(global_norm))

    def adjust_lr(self):
        if self.config['warmup_style'] == ac.ORG_WARMUP:
            # original Transformer schedule: linear warmup for warmup_steps batches,
            # then decay with the inverse square root of the step number
            step = self.num_batches_done + 1.0
            if step < self.config['warmup_steps']:
                lr = self.config['embed_dim']**(-0.5) * step * self.config['warmup_steps']**(-1.5)
            else:
                lr = max(self.config['embed_dim']**(-0.5) * step**(-0.5),
                         self.config['min_lr'])
            for p in self.optimizer.param_groups:
                p['lr'] = lr

    def train(self):
        self.model.train()
        train_ids_file = self.data_manager.data_files['ids']
        for e in range(self.max_epochs):
            b = 0
            for batch_data in self.data_manager.get_batch(ids_file=train_ids_file,
                                                          shuffle=True,
                                                          num_preload=self.num_preload):
                b += 1
                self.run_log(b, e, batch_data)
                if not self.val_per_epoch:
                    self.maybe_validate()

            self.report_epoch(e + 1)
            if self.val_per_epoch and (e + 1) % self.validate_freq == 0:
                self.maybe_validate(just_validate=True)

        # validate one last time
        if not self.config['val_per_epoch']:
            self.maybe_validate(just_validate=True)

        self.logger.info('It is finally done, mate!')
        self.logger.info('Train smoothed perps:')
        self.logger.info(', '.join(map(str, self.train_smooth_perps)))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join(map(str, self.train_true_perps)))
        numpy.save(join(self.config['save_to'], 'train_smooth_perps.npy'),
                   self.train_smooth_perps)
        numpy.save(join(self.config['save_to'], 'train_true_perps.npy'),
                   self.train_true_perps)

        self.logger.info('Save final checkpoint')
        self.save_checkpoint()

        # Evaluate on test
        for checkpoint in self.data_manager.checkpoints:
            self.logger.info('Translate for {}'.format(checkpoint))
            dev_file = self.data_manager.dev_files[checkpoint][self.data_manager.src_lang]
            test_file = self.data_manager.test_files[checkpoint][self.data_manager.src_lang]
            if exists(test_file):
                self.logger.info('  Evaluate on test')
                self.restart_to_best_checkpoint(checkpoint)
                self.validator.translate(self.model, test_file)
                self.logger.info('  Also translate dev')
                self.validator.translate(self.model, dev_file)

    def save_checkpoint(self):
        cpkt_path = join(self.config['save_to'],
                         '{}.pth'.format(self.config['model_name']))
        torch.save(self.model.state_dict(), cpkt_path)

    def restart_to_best_checkpoint(self, checkpoint):
        best_perp = numpy.min(self.validator.best_perps[checkpoint])
        best_cpkt_path = self.validator.get_cpkt_path(checkpoint, best_perp)
        self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
        self.model.load_state_dict(torch.load(best_cpkt_path))

    def maybe_validate(self, just_validate=False):
        if self.num_batches_done % self.validate_freq == 0 or just_validate:
            self.save_checkpoint()
            self.validator.validate_and_save(self.model)

            # if doing annealing
            if self.config['warmup_style'] == ac.NO_WARMUP and self.lr_decay > 0:
                cond = (len(self.validator.perp_curve) > self.patience
                        and self.validator.perp_curve[-1]
                        > max(self.validator.perp_curve[-1 - self.patience:-1]))
                if cond:
                    metric = 'perp'
                    scores = self.validator.perp_curve[-1 - self.patience:]
                    scores = ', '.join(map(str, list(scores)))
                    self.logger.info('Past {} are {}'.format(metric, scores))
                    # when not using warmup, decay the lr if dev perf has not improved
                    if self.lr * self.lr_decay >= self.config['min_lr']:
                        self.logger.info('Anneal the learning rate from {} to {}'.format(
                            self.lr, self.lr * self.lr_decay))
                        self.lr = self.lr * self.lr_decay
                        for p in self.optimizer.param_groups:
                            p['lr'] = self.lr
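
# Illustrative usage sketch, not part of the original script: the PyTorch Trainer
# above only reads args.proto and args.num_preload, so a tiny namespace is enough
# to drive it. The helper name and the default num_preload are assumptions. The
# class is captured via a default argument because a second, older Trainer class
# with the same name is defined further down in this file.
def run_pytorch_trainer(proto, num_preload=1000, trainer_cls=Trainer):
    from argparse import Namespace
    args = Namespace(proto=proto, num_preload=num_preload)  # hypothetical args object
    trainer = trainer_cls(args)  # trainer_cls is the PyTorch Trainer defined above
    trainer.train()
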
class Trainer(object):
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.config['fixed_var_list'] = args.fixed_var_list
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        self.lr = self.config['lr']
        self.max_epochs = self.config['max_epochs']
        self.save_freq = self.config['save_freq']
        self.cpkt_path = None
        self.validate_freq = None
        self.train_perps = []

        self.saver = None
        self.train_m = None
        self.dev_m = None

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)
        self.validate_freq = ut.get_validation_frequency(
            self.data_manager.length_files[ac.TRAINING],
            self.config['validate_freq'],
            self.config['batch_size'])
        self.logger.info('Evaluate every {} batches'.format(self.validate_freq))

        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)

        # For logging
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_train_weights = 0.
        self.num_batches_done = 0  # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch, sounds like that tabloid

    def get_model(self, mode):
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d, seed=ac.SEED)
        with tf.variable_scope(self.config['model_name'],
                               reuse=reuse,
                               initializer=initializer):
            return Model(self.config, mode)

    def reload_and_get_cpkt_saver(self, config, sess):
        cpkt_path = join(config['save_to'],
                         '{}.cpkt'.format(config['model_name']))
        saver = tf.train.Saver(var_list=tf.trainable_variables(),
                               max_to_keep=config['n_best'] + 1)

        # since TF >= 0.11 a checkpoint is no longer a single .cpkt file,
        # so check for the .meta file instead
        if exists(cpkt_path + '.meta') and config['reload']:
            self.logger.info('Reload model from {}'.format(cpkt_path))
            saver.restore(sess, cpkt_path)

        self.cpkt_path = cpkt_path
        self.saver = saver

    def train(self):
        tf.reset_default_graph()
        with tf.Graph().as_default():
            tf.set_random_seed(ac.SEED)
            self.train_m = self.get_model(ac.TRAINING)
            self.dev_m = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                self.reload_and_get_cpkt_saver(self.config, sess)

                self.logger.info('Set learning rate to {}'.format(self.lr))
                sess.run(tf.assign(self.train_m.lr, self.lr))

                for e in xrange(self.max_epochs):
                    for b, batch_data in self.data_manager.get_batch(
                            mode=ac.TRAINING, num_batches=self.num_preload):
                        self.run_log_save(sess, b, e, batch_data)
                        self.maybe_validate(sess)

                    self.report_epoch(e)

                self.logger.info('It is finally done, mate!')
                self.logger.info('Train perplexities:')
                self.logger.info(', '.join(map(str, self.train_perps)))
                numpy.save(join(self.config['save_to'], 'train_perps.npy'),
                           self.train_perps)

                self.logger.info('Save final checkpoint')
                self.saver.save(sess, self.cpkt_path)

                # Evaluate on test
                if exists(self.data_manager.ids_files[ac.TESTING]):
                    self.logger.info('Evaluate on test')
                    best_bleu = numpy.max(self.validator.best_bleus)
                    best_cpkt_path = self.validator.get_cpkt_path(best_bleu)
                    self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
                    self.saver.restore(sess, best_cpkt_path)
                    self.validator.evaluate(sess, self.dev_m, mode=ac.TESTING)

    def report_epoch(self, e):
        self.logger.info('Finish epoch {}'.format(e + 1))
        self.logger.info('    It takes {}'.format(ut.format_seconds(self.epoch_time)))
        self.logger.info('    Average # words/second {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch {}'.format(
            self.epoch_time / self.epoch_batches_done))

        train_perp = self.epoch_loss / self.epoch_weights

        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        train_perp = numpy.exp(train_perp) if train_perp < 300 else float('inf')
        self.train_perps.append(train_perp)
        self.logger.info('    train perplexity: {}'.format(train_perp))

    def sample_input(self, batch_data):
        # TODO: more on sampling
        src_sent = batch_data[0][0]
        src_length = batch_data[1][0]
        trg_sent = batch_data[2][0]
        target = batch_data[3][0]
        weight = batch_data[4][0]

        src_sent = map(self.src_ivocab.get, src_sent)
        src_sent = u' '.join(src_sent)
        trg_sent = map(self.trg_ivocab.get, trg_sent)
        trg_sent = u' '.join(trg_sent)
        target = map(self.trg_ivocab.get, target)
        target = u' '.join(target)
        weight = ' '.join(map(str, weight))

        self.logger.info('Sample input data:')
        self.logger.info(u'Src: {}'.format(src_sent))
        self.logger.info(u'Src len: {}'.format(src_length))
        self.logger.info(u'Trg: {}'.format(trg_sent))
        self.logger.info(u'Tar: {}'.format(target))
        self.logger.info(u'W: {}'.format(weight))

    def run_log_save(self, sess, b, e, batch_data):
        start = time.time()
        src_inputs, src_seq_lengths, trg_inputs, trg_targets, target_weights = batch_data
        feed = {
            self.train_m.src_inputs: src_inputs,
            self.train_m.src_seq_lengths: src_seq_lengths,
            self.train_m.trg_inputs: trg_inputs,
            self.train_m.trg_targets: trg_targets,
            self.train_m.target_weights: target_weights
        }
        loss, _ = sess.run([self.train_m.loss, self.train_m.train_op], feed)

        num_trg_words = numpy.sum(target_weights)
        self.num_batches_done += 1

        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_weights += num_trg_words

        self.log_train_loss += loss
        self.log_train_weights += num_trg_words

        self.epoch_time += time.time() - start

        if self.num_batches_done % (10 * self.log_freq) == 0:
            self.sample_input(batch_data)

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_word_perp = self.log_train_loss / self.log_train_weights
            avg_word_perp = numpy.exp(avg_word_perp) if avg_word_perp < 300 else float('inf')
            self.log_train_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(b, e + 1, self.max_epochs))
            self.logger.info('   avg word perp:   {0:.2f}'.format(avg_word_perp))
            self.logger.info('   acc trg words/s: {}'.format(int(acc_speed_word)))
            self.logger.info('   acc sec/batch:   {0:.2f}'.format(acc_speed_time))

        if self.num_batches_done % self.save_freq == 0:
            start = time.time()
            self.saver.save(sess, self.cpkt_path)
            self.logger.info('Save model to {}, takes {}'.format(
                self.cpkt_path, ut.format_seconds(time.time() - start)))

    def maybe_validate(self, sess):
        if self.num_batches_done % self.validate_freq == 0:
            self.validator.validate_and_save(sess, self.dev_m, self.saver)

    def _call_me_maybe(self):
        pass  # NO
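
# Illustrative entry point, an assumption rather than part of the original repo:
# the TensorFlow Trainer above only reads args.proto, args.fixed_var_list and
# args.num_preload, so a small argparse wrapper is enough to launch training.
# Flag names and defaults below are hypothetical.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', required=True,
                        help='name of a configuration function in configurations')
    parser.add_argument('--fixed-var-list', dest='fixed_var_list', default=None,
                        help='variables to keep fixed during training, if any')
    parser.add_argument('--num-preload', dest='num_preload', type=int, default=1000,
                        help='number of batches to preload per read')
    Trainer(parser.parse_args()).train()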