import codecs
import os
import time
from os.path import exists, join

import numpy
import tensorflow as tf
import matplotlib.pyplot as plt

# Project-local modules; the module names below are assumed from how the
# identifiers are used in this file.
import configurations
import all_constants as ac
import utils as ut
from data_manager import DataManager
from model import Model
from validator import Validator


class Trainer(object):
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.config['fixed_var_list'] = args.fixed_var_list
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        self.lr = self.config['lr']
        self.max_epochs = self.config['max_epochs']
        self.save_freq = self.config['save_freq']
        self.cpkt_path = None
        self.validate_freq = None
        self.train_perps = []

        self.saver = None
        self.train_m = None
        self.dev_m = None

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)
        self.validate_freq = ut.get_validation_frequency(
            self.data_manager.length_files[ac.TRAINING],
            self.config['validate_freq'], self.config['batch_size'])
        self.logger.info('Evaluate every {} batches'.format(self.validate_freq))

        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)

        # For logging
        self.log_freq = 100  # log train stats every this many batches
        self.log_train_loss = 0.  # total train loss over the last log_freq batches
        self.log_train_weights = 0.
        self.num_batches_done = 0  # number of batches done over the whole training run
        self.epoch_batches_done = 0  # number of batches done in this epoch
        self.epoch_loss = 0.  # total train loss for the whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for the whole epoch
        self.epoch_time = 0.  # total exec time for the whole epoch, sounds like that tabloid

    def get_model(self, mode):
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d, seed=ac.SEED)
        with tf.variable_scope(self.config['model_name'], reuse=reuse,
                               initializer=initializer):
            return Model(self.config, mode)

    def reload_and_get_cpkt_saver(self, config, sess):
        cpkt_path = join(config['save_to'], '{}.cpkt'.format(config['model_name']))
        saver = tf.train.Saver(var_list=tf.trainable_variables(),
                               max_to_keep=config['n_best'] + 1)

        # Since TF ~0.11, a checkpoint is no longer a single .cpkt file,
        # so check for the .meta file instead.
        if exists(cpkt_path + '.meta') and config['reload']:
            self.logger.info('Reload model from {}'.format(cpkt_path))
            saver.restore(sess, cpkt_path)

        self.cpkt_path = cpkt_path
        self.saver = saver

    def train(self):
        tf.reset_default_graph()
        with tf.Graph().as_default():
            tf.set_random_seed(ac.SEED)
            self.train_m = self.get_model(ac.TRAINING)
            self.dev_m = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                self.reload_and_get_cpkt_saver(self.config, sess)

                self.logger.info('Set learning rate to {}'.format(self.lr))
                sess.run(tf.assign(self.train_m.lr, self.lr))

                for e in range(self.max_epochs):
                    for b, batch_data in self.data_manager.get_batch(
                            mode=ac.TRAINING, num_batches=self.num_preload):
                        self.run_log_save(sess, b, e, batch_data)
                        self.maybe_validate(sess)
                    self.report_epoch(e)

                self.logger.info('It is finally done, mate!')
                self.logger.info('Train perplexities:')
                self.logger.info(', '.join(map(str, self.train_perps)))
                numpy.save(join(self.config['save_to'], 'train_perps.npy'),
                           self.train_perps)

                self.logger.info('Save final checkpoint')
                self.saver.save(sess, self.cpkt_path)

                # Evaluate on test
                if exists(self.data_manager.ids_files[ac.TESTING]):
                    self.logger.info('Evaluate on test')
                    best_bleu = numpy.max(self.validator.best_bleus)
                    best_cpkt_path = self.validator.get_cpkt_path(best_bleu)
                    self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
                    self.saver.restore(sess, best_cpkt_path)
                    self.validator.evaluate(sess, self.dev_m, mode=ac.TESTING)

    def report_epoch(self, e):
        self.logger.info('Finish epoch {}'.format(e + 1))
        self.logger.info('    It takes {}'.format(ut.format_seconds(self.epoch_time)))
        self.logger.info('    Average # words/second {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch {}'.format(
            self.epoch_time / self.epoch_batches_done))

        train_perp = self.epoch_loss / self.epoch_weights
        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        train_perp = numpy.exp(train_perp) if train_perp < 300 else float('inf')
        self.train_perps.append(train_perp)
        self.logger.info('    train perplexity: {}'.format(train_perp))

    def sample_input(self, batch_data):
        # TODO: more on sampling
        src_sent = batch_data[0][0]
        src_length = batch_data[1][0]
        trg_sent = batch_data[2][0]
        target = batch_data[3][0]
        weight = batch_data[4][0]

        src_sent = u' '.join(map(self.src_ivocab.get, src_sent))
        trg_sent = u' '.join(map(self.trg_ivocab.get, trg_sent))
        target = u' '.join(map(self.trg_ivocab.get, target))
        weight = ' '.join(map(str, weight))

        self.logger.info('Sample input data:')
        self.logger.info(u'Src: {}'.format(src_sent))
        self.logger.info(u'Src len: {}'.format(src_length))
        self.logger.info(u'Trg: {}'.format(trg_sent))
        self.logger.info(u'Tar: {}'.format(target))
        self.logger.info(u'W: {}'.format(weight))

    def run_log_save(self, sess, b, e, batch_data):
        start = time.time()
        src_inputs, src_seq_lengths, trg_inputs, trg_targets, target_weights = batch_data
        feed = {
            self.train_m.src_inputs: src_inputs,
            self.train_m.src_seq_lengths: src_seq_lengths,
            self.train_m.trg_inputs: trg_inputs,
            self.train_m.trg_targets: trg_targets,
            self.train_m.target_weights: target_weights
        }
        loss, _ = sess.run([self.train_m.loss, self.train_m.train_op], feed)

        num_trg_words = numpy.sum(target_weights)
        self.num_batches_done += 1
        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_weights += num_trg_words
        self.log_train_loss += loss
        self.log_train_weights += num_trg_words
        self.epoch_time += time.time() - start

        if self.num_batches_done % (10 * self.log_freq) == 0:
            self.sample_input(batch_data)

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_word_perp = self.log_train_loss / self.log_train_weights
            avg_word_perp = numpy.exp(avg_word_perp) if avg_word_perp < 300 else float('inf')
            self.log_train_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(b, e + 1, self.max_epochs))
            self.logger.info('  avg word perp:   {0:.2f}'.format(avg_word_perp))
            self.logger.info('  acc trg words/s: {}'.format(int(acc_speed_word)))
            self.logger.info('  acc sec/batch:   {0:.2f}'.format(acc_speed_time))

        if self.num_batches_done % self.save_freq == 0:
            start = time.time()
            self.saver.save(sess, self.cpkt_path)
            self.logger.info('Save model to {}, takes {}'.format(
                self.cpkt_path, ut.format_seconds(time.time() - start)))

    def maybe_validate(self, sess):
        if self.num_batches_done % self.validate_freq == 0:
            self.validator.validate_and_save(sess, self.dev_m, self.saver)

    def _call_me_maybe(self):
        pass  # NO
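
# Illustrative helper (not part of the original file): both report_epoch and
# run_log_save compute per-word perplexity as exp(summed cross-entropy loss /
# number of target words), skipping exp() once the exponent reaches 300 so the
# result cannot overflow. The numbers in the usage note below are made up.
def word_perplexity(total_loss, total_weights):
    avg_loss = total_loss / total_weights
    return numpy.exp(avg_loss) if avg_loss < 300 else float('inf')

# e.g. a summed loss of 45000.0 nats over 10000 target words gives
# word_perplexity(45000.0, 10000.0) == numpy.exp(4.5), roughly 90.0.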
class Translator(object):
    def __init__(self, args):
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.reverse = self.config['reverse']
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file
        self.plot_align = args.plot_align
        self.unk_repl = args.unk_repl

        if self.input_file is None or self.model_file is None \
                or not os.path.exists(self.input_file) \
                or not os.path.exists(self.model_file + '.meta'):
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)
        self.translate()

    def ids_to_trans(self, trans_ids, trans_alignments, no_unk_src_toks):
        words = []
        word_ids = []
        # Could have done better but this is clearer to me
        if not self.unk_repl:
            for idx, word_idx in enumerate(trans_ids):
                if word_idx == ac.EOS_ID:
                    break
                words.append(self.trg_ivocab[word_idx])
                word_ids.append(word_idx)
        else:
            for idx, word_idx in enumerate(trans_ids):
                if word_idx == ac.EOS_ID:
                    break
                if word_idx == ac.UNK_ID:
                    # Replace UNK with the highest-attention source word
                    alignment = trans_alignments[idx]
                    highest_att_src_tok_pos = numpy.argmax(alignment)
                    words.append(no_unk_src_toks[highest_att_src_tok_pos])
                else:
                    words.append(self.trg_ivocab[word_idx])
                word_ids.append(word_idx)

        return u' '.join(words), word_ids

    def get_model(self, mode):
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d)
        with tf.variable_scope(self.config['model_name'], reuse=reuse,
                               initializer=initializer):
            return Model(self.config, mode)

    def get_trans(self, probs, scores, symbols, parents, alignments, no_unk_src_toks):
        sorted_rows = numpy.argsort(scores[:, -1])[::-1]
        best_trans_alignments = []
        best_trans = None
        best_tran_ids = None
        beam_trans = []
        for i, r in enumerate(sorted_rows):
            row_idx = r
            col_idx = scores.shape[1] - 1

            trans_ids = []
            trans_alignments = []
            while col_idx >= 0:
                trans_ids.append(symbols[row_idx, col_idx])
                align = alignments[row_idx, col_idx, :]
                trans_alignments.append(align)

                if i == 0:
                    best_trans_alignments.append(align if not self.reverse else align[::-1])

                row_idx = parents[row_idx, col_idx]
                col_idx -= 1

            trans_ids = trans_ids[::-1]
            trans_alignments = trans_alignments[::-1]
            trans_out, trans_out_ids = self.ids_to_trans(trans_ids, trans_alignments,
                                                         no_unk_src_toks)
            beam_trans.append(u'{} {:.2f} {:.2f}'.format(trans_out, scores[r, -1],
                                                         probs[r, -1]))
            if i == 0:  # highest-prob translation
                best_trans = trans_out
                best_tran_ids = trans_out_ids

        return best_trans, best_tran_ids, u'\n'.join(beam_trans), best_trans_alignments[::-1]

    def plot_head_map(self, mma, target_labels, target_ids, source_labels,
                      source_ids, filename):
        """https://github.com/EdinburghNLP/nematus/blob/master/utils/plot_heatmap.py

        Change the font in the family param below. If the system font is not
        used, delete the matplotlib font cache:
        https://github.com/matplotlib/matplotlib/issues/3590
        """
        fig, ax = plt.subplots()
        heatmap = ax.pcolor(mma, cmap=plt.cm.Blues)

        # put the major ticks at the middle of each cell
        ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False)
        ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False)

        # without this there are extra columns/rows:
        # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column
        ax.set_xlim(0, int(mma.shape[1]))
        ax.set_ylim(0, int(mma.shape[0]))

        # want a more natural, table-like display
        ax.invert_yaxis()
        ax.xaxis.tick_top()

        # source words -> column labels
        ax.set_xticklabels(source_labels, minor=False, family='Source Code Pro')
        for xtick, idx in zip(ax.get_xticklabels(), source_ids):
            if idx == ac.UNK_ID:
                xtick.set_color('b')
        # target words -> row labels
        ax.set_yticklabels(target_labels, minor=False, family='Source Code Pro')
        for ytick, idx in zip(ax.get_yticklabels(), target_ids):
            if idx == ac.UNK_ID:
                ytick.set_color('b')

        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(filename)
        plt.close('all')

    def translate(self):
        with tf.Graph().as_default():
            # Build the training graph first so the validating graph can
            # reuse its variables.
            train_model = self.get_model(ac.TRAINING)
            model = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                self.logger.info('Restore model from {}'.format(self.model_file))
                saver = tf.train.Saver(var_list=tf.trainable_variables())
                saver.restore(sess, self.model_file)

                best_trans_file = self.input_file + '.best_trans'
                beam_trans_file = self.input_file + '.beam_trans'
                open(best_trans_file, 'w').close()
                open(beam_trans_file, 'w').close()
                ftrans = codecs.open(best_trans_file, 'w', 'utf-8')
                btrans = codecs.open(beam_trans_file, 'w', 'utf-8')

                self.logger.info('Start translating {}'.format(self.input_file))
                start = time.time()
                count = 0
                for (src_input, src_seq_len,
                     no_unk_src_toks) in self.data_manager.get_trans_input(self.input_file):
                    feed = {
                        model.src_inputs: src_input,
                        model.src_seq_lengths: src_seq_len
                    }
                    probs, scores, symbols, parents, alignments = sess.run(
                        [model.probs, model.scores, model.symbols,
                         model.parents, model.alignments], feed_dict=feed)

                    alignments = numpy.transpose(alignments, axes=(1, 0, 2))
                    probs = numpy.transpose(numpy.array(probs))
                    scores = numpy.transpose(numpy.array(scores))
                    symbols = numpy.transpose(numpy.array(symbols))
                    parents = numpy.transpose(numpy.array(parents))

                    best_trans, best_trans_ids, beam_trans, best_trans_alignments = \
                        self.get_trans(probs, scores, symbols, parents, alignments,
                                       no_unk_src_toks)
                    ftrans.write(best_trans + '\n')
                    btrans.write(beam_trans + '\n\n')

                    if self.plot_align:
                        src_input = numpy.reshape(src_input, [-1])
                        if self.reverse:
                            src_input = src_input[::-1]
                            no_unk_src_toks = no_unk_src_toks[::-1]
                        trans_toks = best_trans.split()
                        best_trans_alignments = numpy.array(best_trans_alignments)[:len(trans_toks)]
                        filename = '{}_{}.png'.format(self.input_file, count)
                        self.plot_head_map(best_trans_alignments, trans_toks,
                                           best_trans_ids, no_unk_src_toks,
                                           src_input, filename)

                    count += 1
                    if count % 100 == 0:
                        self.logger.info('  Translating line {}, average {} seconds/sent'.format(
                            count, (time.time() - start) / count))

                ftrans.close()
                btrans.close()
                self.logger.info('Done translating {}, it takes {} minutes'.format(
                    self.input_file, float(time.time() - start) / 60.0))
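
# Toy sketch (made-up data, not part of the original file) of the
# parent-pointer backtracking that get_trans performs above: each column of
# `symbols`/`parents` is one decoding step, and a hypothesis is recovered by
# walking right-to-left, following each entry's parent beam row.
def backtrack_beam(symbols, parents, row):
    ids = []
    col = symbols.shape[1] - 1
    while col >= 0:
        ids.append(symbols[row, col])
        row = parents[row, col]
        col -= 1
    return ids[::-1]

# With symbols = numpy.array([[7, 3, 2], [7, 5, 2]]) and
#      parents = numpy.array([[0, 0, 1], [0, 0, 0]]),
# backtrack_beam(symbols, parents, row=0) follows 2 -> 5 -> 7 through the
# parent rows and returns [7, 5, 2].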
class Translator(object):
    # A later revision of the Translator above: get_trans and plot_head_map
    # have moved into the ut helper module, and unk_repl now comes from the
    # config instead of the command-line args.
    def __init__(self, args):
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.reverse = self.config['reverse']
        self.unk_repl = self.config['unk_repl']
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file
        self.plot_align = args.plot_align

        if self.input_file is None or self.model_file is None \
                or not os.path.exists(self.input_file) \
                or not os.path.exists(self.model_file + '.meta'):
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)
        self.translate()

    def get_model(self, mode):
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d)
        with tf.variable_scope(self.config['model_name'], reuse=reuse,
                               initializer=initializer):
            return Model(self.config, mode)

    def translate(self):
        with tf.Graph().as_default():
            # Build the training graph first so the validating graph can
            # reuse its variables.
            train_model = self.get_model(ac.TRAINING)
            model = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                self.logger.info('Restore model from {}'.format(self.model_file))
                saver = tf.train.Saver(var_list=tf.trainable_variables())
                saver.restore(sess, self.model_file)

                best_trans_file = self.input_file + '.best_trans'
                beam_trans_file = self.input_file + '.beam_trans'
                open(best_trans_file, 'w').close()
                open(beam_trans_file, 'w').close()
                ftrans = codecs.open(best_trans_file, 'w', 'utf-8')
                btrans = codecs.open(beam_trans_file, 'w', 'utf-8')

                self.logger.info('Start translating {}'.format(self.input_file))
                start = time.time()
                count = 0
                for (src_input, src_seq_len,
                     no_unk_src_toks) in self.data_manager.get_trans_input(self.input_file):
                    feed = {
                        model.src_inputs: src_input,
                        model.src_seq_lengths: src_seq_len
                    }
                    probs, scores, symbols, parents, alignments = sess.run(
                        [model.probs, model.scores, model.symbols,
                         model.parents, model.alignments], feed_dict=feed)

                    alignments = numpy.transpose(alignments, axes=(1, 0, 2))
                    probs = numpy.transpose(numpy.array(probs))
                    scores = numpy.transpose(numpy.array(scores))
                    symbols = numpy.transpose(numpy.array(symbols))
                    parents = numpy.transpose(numpy.array(parents))

                    best_trans, best_trans_ids, beam_trans, best_trans_alignments = ut.get_trans(
                        probs, scores, symbols, parents, alignments, no_unk_src_toks,
                        self.trg_ivocab, reverse=self.reverse, unk_repl=self.unk_repl)

                    # Drop the trailing EOS token from the best translation
                    best_trans_wo_eos = u' '.join(best_trans.split()[:-1])
                    ftrans.write(best_trans_wo_eos + '\n')
                    btrans.write(beam_trans + '\n\n')

                    if self.plot_align:
                        src_input = numpy.reshape(src_input, [-1])
                        if self.reverse:
                            src_input = src_input[::-1]
                            no_unk_src_toks = no_unk_src_toks[::-1]
                        trans_toks = best_trans.split()
                        best_trans_alignments = numpy.array(best_trans_alignments)[:len(trans_toks)]
                        filename = '{}_{}.png'.format(self.input_file, count)
                        ut.plot_head_map(best_trans_alignments, trans_toks,
                                         best_trans_ids, no_unk_src_toks,
                                         src_input, filename)

                    count += 1
                    if count % 100 == 0:
                        self.logger.info('  Translating line {}, average {} seconds/sent'.format(
                            count, (time.time() - start) / count))

                ftrans.close()
                btrans.close()
                self.logger.info('Done translating {}, it takes {} minutes'.format(
                    self.input_file, float(time.time() - start) / 60.0))
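
# Usage sketch (hypothetical): classes like the ones in this file are
# typically driven from a small CLI entry point. The flag names below are
# assumptions inferred from the attributes read off `args` in
# Translator.__init__; they are not a confirmed interface of this repo.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', required=True,
                        help='name of a config function in configurations')
    parser.add_argument('--input-file', dest='input_file', required=True)
    parser.add_argument('--model-file', dest='model_file', required=True,
                        help='checkpoint prefix; <prefix>.meta must exist')
    parser.add_argument('--plot-align', dest='plot_align', action='store_true')
    Translator(parser.parse_args())  # __init__ kicks off self.translate()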