def __init__(self, args):
    """Build the full PyTorch training harness.

    Wires together, in order: config, logger, device, data manager,
    validator, logging counters, model, and Adam optimizer.

    :param args: parsed CLI args; must provide ``proto`` (name of a config
        factory in ``configurations``) and ``num_preload``.
    """
    super(Trainer, self).__init__()
    # `proto` names a zero-arg config factory in the configurations module.
    self.config = getattr(configurations, args.proto)()
    self.num_preload = args.num_preload
    self.logger = ut.get_logger(self.config['log_file'])
    # Train on the first GPU when available, otherwise fall back to CPU.
    self.device = torch.device(
        'cuda:0' if torch.cuda.is_available() else 'cpu')
    # Training hyper-parameters cached off the config dict.
    self.normalize_loss = self.config['normalize_loss']
    self.patience = self.config['patience']
    self.lr = self.config['lr']
    self.lr_decay = self.config['lr_decay']
    self.max_epochs = self.config['max_epochs']
    self.warmup_steps = self.config['warmup_steps']
    # Per-epoch perplexity history, filled in by report_epoch().
    self.train_smooth_perps = []
    self.train_true_perps = []
    # Data manager must exist before the validator, which consumes it.
    self.data_manager = DataManager(self.config)
    self.validator = Validator(self.config, self.data_manager)
    # Validation cadence: every N epochs or every N batches.
    self.val_per_epoch = self.config['val_per_epoch']
    self.validate_freq = int(self.config['validate_freq'])
    self.logger.info('Evaluate every {} {}'.format(
        self.validate_freq, 'epochs' if self.val_per_epoch else 'batches'))
    # For logging
    self.log_freq = 100  # log train stat every this-many batches
    self.log_train_loss = 0.  # total train loss every log_freq batches
    self.log_nll_loss = 0.
    self.log_train_weights = 0.
    self.num_batches_done = 0  # number of batches done for the whole training
    self.epoch_batches_done = 0  # number of batches done for this epoch
    self.epoch_loss = 0.  # total train loss for whole epoch
    self.epoch_nll_loss = 0.  # total train loss for whole epoch
    self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
    self.epoch_time = 0.  # total exec time for whole epoch, sounds like that tabloid
    # get model
    self.model = Model(self.config).to(self.device)
    param_count = sum(
        [numpy.prod(p.size()) for p in self.model.parameters()])
    self.logger.info('Model has {:,} parameters'.format(param_count))
    # get optimizer
    beta1 = self.config['beta1']
    beta2 = self.config['beta2']
    epsilon = self.config['epsilon']
    self.optimizer = torch.optim.Adam(self.model.parameters(),
                                      lr=self.lr,
                                      betas=(beta1, beta2),
                                      eps=epsilon)
def __init__(self, args):
    """Load config, validate the input/model paths, then translate.

    Raises ValueError when either path is missing or does not exist.
    """
    super(Translator, self).__init__()
    self.config = getattr(configurations, args.proto)()
    self.logger = ut.get_logger(self.config['log_file'])
    self.input_file = args.input_file
    self.model_file = args.model_file
    # Guard: both paths must be given and present on disk.
    paths_ok = (self.input_file is not None
                and self.model_file is not None
                and os.path.exists(self.input_file)
                and os.path.exists(self.model_file))
    if not paths_ok:
        raise ValueError('Input file or model file does not exist')
    self.data_manager = DataManager(self.config)
    # Kick off translation immediately on construction.
    self.translate()
def __init__(self, args):
    """Build the TF translator: config, vocab lookups, then translate.

    :param args: parsed CLI args providing ``proto``, ``input_file``,
        ``model_file``, ``plot_align`` and ``unk_repl``.
    """
    super(Translator, self).__init__()
    self.config = getattr(configurations, args.proto)()
    # Whether source sentences were reversed during preprocessing.
    self.reverse = self.config['reverse']
    self.logger = ut.get_logger(self.config['log_file'])
    self.input_file = args.input_file
    self.model_file = args.model_file
    self.plot_align = args.plot_align
    self.unk_repl = args.unk_repl
    # TF checkpoints are multi-file; the '.meta' graph file is the one
    # whose existence is checked here.
    if self.input_file is None or self.model_file is None or not os.path.exists(self.input_file) or not os.path.exists(self.model_file + '.meta'):
        raise ValueError('Input file or model file does not exist')
    self.data_manager = DataManager(self.config)
    # Inverse vocabularies (id -> token) for source and target languages.
    _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
    _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)
    # Kick off translation immediately on construction.
    self.translate()
def __init__(self, config, load_from=None):
    """Construct the model, either fresh or restored from a checkpoint.

    :param config: configuration dict.
    :param load_from: optional checkpoint path; when given, state is
        restored instead of initializing embeddings/submodules.
    """
    super(Model, self).__init__()
    self.config = config
    self.struct = self.config['struct']
    self.decoder_mask = None
    # Skip vocab initialization when restoring: the checkpoint already
    # carries the trained embeddings.
    self.data_manager = DataManager(config, init_vocab=(not load_from))
    if load_from:
        # NOTE(review): `do_init` is not a kwarg of nn.Module.load_state_dict;
        # presumably this class overrides load_state_dict — confirm.
        self.load_state_dict(torch.load(load_from,
                                        map_location=ut.get_device()),
                             do_init=True)
    else:
        self.init_embeddings()
        self.init_model()
    self.add_struct_params()
    # dict where keys are data_ptrs to dicts of parameter options
    # see https://pytorch.org/docs/stable/optim.html#per-parameter-options
    self.parameter_attrs = {}
def __init__(self, args):
    """Build the TensorFlow training harness (legacy TF1 variant).

    :param args: parsed CLI args providing ``proto``, ``fixed_var_list``
        and ``num_preload``.
    """
    super(Trainer, self).__init__()
    self.config = getattr(configurations, args.proto)()
    # Variables listed here are excluded from training (kept fixed).
    self.config['fixed_var_list'] = args.fixed_var_list
    self.num_preload = args.num_preload
    self.logger = ut.get_logger(self.config['log_file'])
    self.lr = self.config['lr']
    self.max_epochs = self.config['max_epochs']
    self.save_freq = self.config['save_freq']
    # Filled in later by reload_and_get_cpkt_saver / get_model.
    self.cpkt_path = None
    self.validate_freq = None
    self.train_perps = []
    self.saver = None
    self.train_m = None
    self.dev_m = None
    self.data_manager = DataManager(self.config)
    self.validator = Validator(self.config, self.data_manager)
    # Convert the config's validation frequency into a batch count.
    self.validate_freq = ut.get_validation_frequency(
        self.data_manager.length_files[ac.TRAINING],
        self.config['validate_freq'], self.config['batch_size'])
    self.logger.info('Evaluate every {} batches'.format(
        self.validate_freq))
    # Inverse vocabularies (id -> token) used by sample_input().
    _, self.src_ivocab = self.data_manager.init_vocab(
        self.data_manager.src_lang)
    _, self.trg_ivocab = self.data_manager.init_vocab(
        self.data_manager.trg_lang)
    # For logging
    self.log_freq = 100  # log train stat every this-many batches
    self.log_train_loss = 0.  # total train loss every log_freq batches
    self.log_train_weights = 0.
    self.num_batches_done = 0  # number of batches done for the whole training
    self.epoch_batches_done = 0  # number of batches done for this epoch
    self.epoch_loss = 0.  # total train loss for whole epoch
    self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
    self.epoch_time = 0.  # total exec time for whole epoch, sounds like that tabloid
class Translator(object):
    """Beam-search translator for a trained TF model.

    Restores a checkpoint, translates an input file line by line, writes
    best and beam translations next to the input file, and optionally
    plots attention heatmaps.
    """

    def __init__(self, args):
        """Build the translator and immediately run translation.

        :param args: parsed CLI args providing ``proto``, ``input_file``,
            ``model_file``, ``plot_align`` and ``unk_repl``.
        :raises ValueError: when input/model paths are missing.
        """
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        # Whether source sentences were reversed during preprocessing.
        self.reverse = self.config['reverse']
        self.logger = ut.get_logger(self.config['log_file'])
        self.input_file = args.input_file
        self.model_file = args.model_file
        self.plot_align = args.plot_align
        self.unk_repl = args.unk_repl
        # TF checkpoints are multi-file; check the '.meta' graph file.
        if self.input_file is None or self.model_file is None or not os.path.exists(self.input_file) or not os.path.exists(self.model_file + '.meta'):
            raise ValueError('Input file or model file does not exist')
        self.data_manager = DataManager(self.config)
        # Inverse vocabularies (id -> token).
        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)
        self.translate()

    def ids_to_trans(self, trans_ids, trans_alignments, no_unk_src_toks):
        """Convert target ids to a translation string, stopping at EOS.

        When ``self.unk_repl`` is set, each UNK token is replaced by the
        source token with the highest attention weight at that step.

        :returns: (space-joined translation string, list of emitted ids)
        """
        words = []
        word_ids = []
        # Could have done better but this is clearer to me
        if not self.unk_repl:
            for idx, word_idx in enumerate(trans_ids):
                if word_idx == ac.EOS_ID:
                    break
                words.append(self.trg_ivocab[word_idx])
                word_ids.append(word_idx)
        else:
            for idx, word_idx in enumerate(trans_ids):
                if word_idx == ac.EOS_ID:
                    break
                if word_idx == ac.UNK_ID:
                    # Replace UNK with highest-attention source word.
                    alignment = trans_alignments[idx]
                    highest_att_src_tok_pos = numpy.argmax(alignment)
                    words.append(no_unk_src_toks[highest_att_src_tok_pos])
                else:
                    words.append(self.trg_ivocab[word_idx])
                word_ids.append(word_idx)
        return u' '.join(words), word_ids

    def get_model(self, mode):
        """Build the model graph; reuse variables for non-training modes."""
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d)
        with tf.variable_scope(self.config['model_name'], reuse=reuse, initializer=initializer):
            return Model(self.config, mode)

    def get_trans(self, probs, scores, symbols, parents, alignments, no_unk_src_toks):
        """Decode beam-search outputs by backtracking parent pointers.

        Rows of ``scores``/``symbols``/``parents`` are beam hypotheses;
        columns are time steps. The best (highest final score) hypothesis
        is decoded first so its alignments can be collected for plotting.

        :returns: (best translation, its ids, all beam hypotheses joined
            by newlines, alignments of the best hypothesis in step order)
        """
        # Beams sorted by final score, best first.
        sorted_rows = numpy.argsort(scores[:, -1])[::-1]
        best_trans_alignments = []
        best_trans = None
        best_tran_ids = None
        beam_trans = []
        for i, r in enumerate(sorted_rows):
            row_idx = r
            col_idx = scores.shape[1] - 1
            trans_ids = []
            trans_alignments = []
            # Walk backwards from the last step, following parent links.
            while True:
                if col_idx < 0:
                    break
                trans_ids.append(symbols[row_idx, col_idx])
                align = alignments[row_idx, col_idx, :]
                trans_alignments.append(align)
                if i == 0:
                    # Undo source reversal so alignments index the original order.
                    best_trans_alignments.append(align if not self.reverse else align[::-1])
                row_idx = parents[row_idx, col_idx]
                col_idx -= 1
            # Backtracking produced reversed sequences; flip to step order.
            trans_ids = trans_ids[::-1]
            trans_alignments = trans_alignments[::-1]
            trans_out, trans_out_ids = self.ids_to_trans(trans_ids, trans_alignments, no_unk_src_toks)
            beam_trans.append(u'{} {:.2f} {:.2f}'.format(trans_out, scores[r, -1], probs[r, -1]))
            if i == 0:  # highest prob trans
                best_trans = trans_out
                best_tran_ids = trans_out_ids
        return best_trans, best_tran_ids, u'\n'.join(beam_trans), best_trans_alignments[::-1]

    def plot_head_map(self, mma, target_labels, target_ids, source_labels, source_ids, filename):
        """Save an attention heatmap (source columns x target rows) to ``filename``.

        https://github.com/EdinburghNLP/nematus/blob/master/utils/plot_heatmap.py
        Change the font in family param below. If the system font is not used,
        delete matplotlib font cache
        https://github.com/matplotlib/matplotlib/issues/3590
        """
        fig, ax = plt.subplots()
        ax.pcolor(mma, cmap=plt.cm.Blues)
        # put the major ticks at the middle of each cell
        ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False)
        ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False)
        # without this I get some extra columns rows
        # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column
        ax.set_xlim(0, int(mma.shape[1]))
        ax.set_ylim(0, int(mma.shape[0]))
        # want a more natural, table-like display
        ax.invert_yaxis()
        ax.xaxis.tick_top()
        # source words -> column labels; UNK tokens colored blue
        ax.set_xticklabels(source_labels, minor=False, family='Source Code Pro')
        for xtick, idx in zip(ax.get_xticklabels(), source_ids):
            if idx == ac.UNK_ID:
                xtick.set_color('b')
        # target words -> row labels; UNK tokens colored blue
        ax.set_yticklabels(target_labels, minor=False, family='Source Code Pro')
        for ytick, idx in zip(ax.get_yticklabels(), target_ids):
            if idx == ac.UNK_ID:
                ytick.set_color('b')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(filename)
        plt.close('all')

    def translate(self):
        """Restore the checkpoint and translate ``self.input_file``.

        Writes '<input>.best_trans' (one best hypothesis per line) and
        '<input>.beam_trans' (all beam hypotheses per sentence).
        """
        with tf.Graph().as_default():
            # Building the training graph first creates the variables that
            # the validation graph then reuses.
            train_model = self.get_model(ac.TRAINING)
            model = self.get_model(ac.VALIDATING)
            with tf.Session() as sess:
                self.logger.info('Restore model from {}'.format(self.model_file))
                saver = tf.train.Saver(var_list=tf.trainable_variables())
                saver.restore(sess, self.model_file)
                best_trans_file = self.input_file + '.best_trans'
                beam_trans_file = self.input_file + '.beam_trans'
                # BUG FIX: builtin open()'s third positional argument is
                # `buffering`, so open(path, 'w', 'utf-8') raised TypeError.
                # Pass encoding by keyword instead. Mode 'w' already
                # truncates, so no separate truncate-only open is needed.
                ftrans = open(best_trans_file, 'w', encoding='utf-8')
                btrans = open(beam_trans_file, 'w', encoding='utf-8')
                self.logger.info('Start translating {}'.format(self.input_file))
                start = time.time()
                count = 0
                for (src_input, src_seq_len, no_unk_src_toks) in self.data_manager.get_trans_input(self.input_file):
                    feed = {
                        model.src_inputs: src_input,
                        model.src_seq_lengths: src_seq_len
                    }
                    probs, scores, symbols, parents, alignments = sess.run(
                        [model.probs, model.scores, model.symbols, model.parents, model.alignments],
                        feed_dict=feed)
                    # Rearrange outputs to (beam, time[, src]) layout.
                    alignments = numpy.transpose(alignments, axes=(1, 0, 2))
                    probs = numpy.transpose(numpy.array(probs))
                    scores = numpy.transpose(numpy.array(scores))
                    symbols = numpy.transpose(numpy.array(symbols))
                    parents = numpy.transpose(numpy.array(parents))
                    best_trans, best_trans_ids, beam_trans, best_trans_alignments = self.get_trans(
                        probs, scores, symbols, parents, alignments, no_unk_src_toks)
                    ftrans.write(best_trans + '\n')
                    btrans.write(beam_trans + '\n\n')
                    if self.plot_align:
                        src_input = numpy.reshape(src_input, [-1])
                        if self.reverse:
                            src_input = src_input[::-1]
                            no_unk_src_toks = no_unk_src_toks[::-1]
                        trans_toks = best_trans.split()
                        # Trim alignments to actual output length.
                        best_trans_alignments = numpy.array(best_trans_alignments)[:len(trans_toks)]
                        filename = '{}_{}.png'.format(self.input_file, count)
                        self.plot_head_map(best_trans_alignments, trans_toks, best_trans_ids,
                                           no_unk_src_toks, src_input, filename)
                    count += 1
                    if count % 100 == 0:
                        self.logger.info(' Translating line {}, average {} seconds/sent'.format(count, (time.time() - start) / count))
                ftrans.close()
                btrans.close()
                self.logger.info('Done translating {}, it takes {} minutes'.format(self.input_file, float(time.time() - start) / 60.0))
class Trainer(object):
    """PyTorch trainer: runs the epoch/batch loop, logs stats, validates,
    checkpoints, and finally evaluates the best checkpoint on dev/test."""

    def __init__(self, args):
        """Wire together config, logging, data, model, and optimizer.

        :param args: parsed CLI args providing ``proto`` (config factory
            name in ``configurations``) and ``num_preload``.
        """
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])
        # First GPU when available, else CPU.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.normalize_loss = self.config['normalize_loss']
        self.patience = self.config['patience']
        self.lr = self.config['lr']
        self.lr_decay = self.config['lr_decay']
        self.max_epochs = self.config['max_epochs']
        self.warmup_steps = self.config['warmup_steps']
        # Per-epoch perplexity history, appended by report_epoch().
        self.train_smooth_perps = []
        self.train_true_perps = []
        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)
        # Validation cadence: every N epochs or every N batches.
        self.val_per_epoch = self.config['val_per_epoch']
        self.validate_freq = int(self.config['validate_freq'])
        self.logger.info('Evaluate every {} {}'.format(
            self.validate_freq, 'epochs' if self.val_per_epoch else 'batches'))
        # For logging
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_nll_loss = 0.
        self.log_train_weights = 0.
        self.num_batches_done = 0  # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_nll_loss = 0.  # total train nll loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch, sounds like that tabloid
        # get model
        self.model = Model(self.config).to(self.device)
        param_count = sum(numpy.prod(p.size()) for p in self.model.parameters())
        self.logger.info('Model has {:,} parameters'.format(param_count))
        # get optimizer
        beta1 = self.config['beta1']
        beta2 = self.config['beta2']
        epsilon = self.config['epsilon']
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          betas=(beta1, beta2),
                                          eps=epsilon)

    def report_epoch(self, e):
        """Log end-of-epoch stats, record perplexities, reset epoch counters.

        :param e: 1-based epoch number (callers pass ``e + 1``).
        """
        self.logger.info('Finish epoch {}'.format(e))
        self.logger.info('    It takes {}'.format(
            ut.format_seconds(self.epoch_time)))
        self.logger.info('    Avergage # words/second {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch {}'.format(
            self.epoch_time / self.epoch_batches_done))
        train_smooth_perp = self.epoch_loss / self.epoch_weights
        train_true_perp = self.epoch_nll_loss / self.epoch_weights
        # Reset per-epoch accumulators for the next epoch.
        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.
        # exp() overflows past ~e^300; report inf instead.
        train_smooth_perp = numpy.exp(
            train_smooth_perp) if train_smooth_perp < 300 else float('inf')
        self.train_smooth_perps.append(train_smooth_perp)
        train_true_perp = numpy.exp(
            train_true_perp) if train_true_perp < 300 else float('inf')
        self.train_true_perps.append(train_true_perp)
        self.logger.info(
            '    smoothed train perplexity: {}'.format(train_smooth_perp))
        self.logger.info(
            '    true train perplexity: {}'.format(train_true_perp))

    def run_log(self, b, e, batch_data):
        """Run one training step on ``batch_data`` and update statistics.

        :param b: batch index within the epoch (for logging only).
        :param e: 0-based epoch index (for logging only).
        :param batch_data: (src_toks, trg_toks, targets) CPU tensors.
        """
        start = time.time()
        src_toks, trg_toks, targets = batch_data
        src_toks_cuda = src_toks.to(self.device)
        trg_toks_cuda = trg_toks.to(self.device)
        targets_cuda = targets.to(self.device)
        # zero grad
        self.optimizer.zero_grad()
        # get loss
        ret = self.model(src_toks_cuda, trg_toks_cuda, targets_cuda)
        loss = ret['loss']
        nll_loss = ret['nll_loss']
        if self.normalize_loss == ac.LOSS_TOK:
            # Normalize by the number of non-pad target tokens.
            opt_loss = loss / (targets_cuda != ac.PAD_ID).type(
                loss.type()).sum()
        elif self.normalize_loss == ac.LOSS_BATCH:
            # BUG FIX: size()[0] is a plain int and has no .type() method
            # (the original raised AttributeError); dividing the loss
            # tensor by the int batch size is the intended normalization.
            opt_loss = loss / targets_cuda.size(0)
        else:
            opt_loss = loss
        opt_loss.backward()
        # clip gradient
        global_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                     self.config['grad_clip'])
        # update (lr schedule is applied before each step)
        self.adjust_lr()
        self.optimizer.step()
        # update training stats
        num_words = (targets != ac.PAD_ID).detach().numpy().sum()
        loss = loss.cpu().detach().numpy()
        nll_loss = nll_loss.cpu().detach().numpy()
        self.num_batches_done += 1
        self.log_train_loss += loss
        self.log_nll_loss += nll_loss
        self.log_train_weights += num_words
        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_nll_loss += nll_loss
        self.epoch_weights += num_words
        self.epoch_time += time.time() - start
        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done
            avg_smooth_perp = self.log_train_loss / self.log_train_weights
            avg_smooth_perp = numpy.exp(
                avg_smooth_perp) if avg_smooth_perp < 300 else float('inf')
            avg_true_perp = self.log_nll_loss / self.log_train_weights
            avg_true_perp = numpy.exp(
                avg_true_perp) if avg_true_perp < 300 else float('inf')
            # Reset the rolling log-window accumulators.
            self.log_train_loss = 0.
            self.log_nll_loss = 0.
            self.log_train_weights = 0.
            self.logger.info('Batch {}, epoch {}/{}:'.format(
                b, e + 1, self.max_epochs))
            self.logger.info(
                '   avg smooth perp:   {0:.2f}'.format(avg_smooth_perp))
            self.logger.info(
                '   avg true perp:   {0:.2f}'.format(avg_true_perp))
            self.logger.info('   acc trg words/s: {}'.format(
                int(acc_speed_word)))
            self.logger.info(
                '   acc sec/batch:   {0:.2f}'.format(acc_speed_time))
            self.logger.info('   global norm:     {0:.2f}'.format(global_norm))

    def adjust_lr(self):
        """Apply the inverse-sqrt warmup schedule (Transformer-style)."""
        if self.config['warmup_style'] == ac.ORG_WARMUP:
            step = self.num_batches_done + 1.0
            if step < self.warmup_steps:
                # Linear warmup phase.
                lr = self.config['embed_dim']**(
                    -0.5) * step * self.warmup_steps**(-1.5)
            else:
                # Inverse-sqrt decay, floored at min_lr.
                lr = max(self.config['embed_dim']**(-0.5) * step**(-0.5),
                         self.config['min_lr'])
            for p in self.optimizer.param_groups:
                p['lr'] = lr

    def train(self):
        """Main training loop over epochs and batches, then final evaluation."""
        self.model.train()
        train_ids_file = self.data_manager.data_files['ids']
        for e in range(self.max_epochs):
            b = 0
            for batch_data in self.data_manager.get_batch(
                    ids_file=train_ids_file,
                    shuffle=True,
                    num_preload=self.num_preload):
                b += 1
                self.run_log(b, e, batch_data)
                if not self.val_per_epoch:
                    self.maybe_validate()
            self.report_epoch(e + 1)
            if self.val_per_epoch and (e + 1) % self.validate_freq == 0:
                self.maybe_validate(just_validate=True)
        # validate 1 last time
        # CONSISTENCY FIX: use the cached self.val_per_epoch like the rest
        # of this method (was self.config['val_per_epoch']).
        if not self.val_per_epoch:
            self.maybe_validate(just_validate=True)
        self.logger.info('It is finally done, mate!')
        self.logger.info('Train smoothed perps:')
        self.logger.info(', '.join(map(str, self.train_smooth_perps)))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join(map(str, self.train_true_perps)))
        numpy.save(join(self.config['save_to'], 'train_smooth_perps.npy'),
                   self.train_smooth_perps)
        numpy.save(join(self.config['save_to'], 'train_true_perps.npy'),
                   self.train_true_perps)
        self.logger.info('Save final checkpoint')
        self.save_checkpoint()
        # Evaluate on test
        for checkpoint in self.data_manager.checkpoints:
            self.logger.info('Translate for {}'.format(checkpoint))
            dev_file = self.data_manager.dev_files[checkpoint][
                self.data_manager.src_lang]
            test_file = self.data_manager.test_files[checkpoint][
                self.data_manager.src_lang]
            if exists(test_file):
                self.logger.info('  Evaluate on test')
                self.restart_to_best_checkpoint(checkpoint)
                self.validator.translate(self.model, test_file)
                self.logger.info('  Also translate dev')
                self.validator.translate(self.model, dev_file)

    def save_checkpoint(self):
        """Save the model's state dict under save_to/<model_name>.pth."""
        cpkt_path = join(self.config['save_to'],
                         '{}.pth'.format(self.config['model_name']))
        torch.save(self.model.state_dict(), cpkt_path)

    def restart_to_best_checkpoint(self, checkpoint):
        """Reload the lowest-perplexity checkpoint recorded by the validator."""
        best_perp = numpy.min(self.validator.best_perps[checkpoint])
        best_cpkt_path = self.validator.get_cpkt_path(checkpoint, best_perp)
        self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
        self.model.load_state_dict(torch.load(best_cpkt_path))

    def maybe_validate(self, just_validate=False):
        """Checkpoint + validate when due; anneal lr when dev perp stalls.

        :param just_validate: force a validation regardless of batch count.
        """
        if self.num_batches_done % self.validate_freq == 0 or just_validate:
            self.save_checkpoint()
            self.validator.validate_and_save(self.model)
            # if doing annealing
            if self.config[
                    'warmup_style'] == ac.NO_WARMUP and self.lr_decay > 0:
                # Decay when the latest dev perp is worse than all of the
                # previous `patience` scores.
                cond = len(
                    self.validator.perp_curve
                ) > self.patience and self.validator.perp_curve[-1] > max(
                    self.validator.perp_curve[-1 - self.patience:-1])
                if cond:
                    metric = 'perp'
                    scores = self.validator.perp_curve[-1 - self.patience:]
                    scores = ', '.join(str(s) for s in scores)
                    self.logger.info('Past {} are {}'.format(metric, scores))
                    # when don't use warmup, decay lr if dev not improve
                    if self.lr * self.lr_decay >= self.config['min_lr']:
                        self.logger.info(
                            'Anneal the learning rate from {} to {}'.format(
                                self.lr, self.lr * self.lr_decay))
                        self.lr = self.lr * self.lr_decay
                        for p in self.optimizer.param_groups:
                            p['lr'] = self.lr
class Trainer(object):
    """Legacy TensorFlow-1.x trainer: session-based epoch/batch loop with
    periodic checkpointing, validation, and a final test-set evaluation.

    NOTE(review): uses ``xrange`` and ``map``-as-list idioms — this class
    appears to be Python-2-only code; confirm the intended interpreter.
    """

    def __init__(self, args):
        """Build the TF training harness from CLI args (see class docstring)."""
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        # Variables listed here are excluded from training (kept fixed).
        self.config['fixed_var_list'] = args.fixed_var_list
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])
        self.lr = self.config['lr']
        self.max_epochs = self.config['max_epochs']
        self.save_freq = self.config['save_freq']
        # Filled in later by reload_and_get_cpkt_saver / train().
        self.cpkt_path = None
        self.validate_freq = None
        self.train_perps = []
        self.saver = None
        self.train_m = None
        self.dev_m = None
        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)
        # Convert the config's validation frequency into a batch count.
        self.validate_freq = ut.get_validation_frequency(
            self.data_manager.length_files[ac.TRAINING],
            self.config['validate_freq'], self.config['batch_size'])
        self.logger.info('Evaluate every {} batches'.format(
            self.validate_freq))
        # Inverse vocabularies (id -> token) used by sample_input().
        _, self.src_ivocab = self.data_manager.init_vocab(
            self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(
            self.data_manager.trg_lang)
        # For logging
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_train_weights = 0.
        self.num_batches_done = 0  # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch, sounds like that tabloid

    def get_model(self, mode):
        """Build the model graph; reuse variables for non-training modes."""
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d, seed=ac.SEED)
        with tf.variable_scope(self.config['model_name'], reuse=reuse, initializer=initializer):
            return Model(self.config, mode)

    def reload_and_get_cpkt_saver(self, config, sess):
        """Create the checkpoint saver and restore a prior run if present."""
        cpkt_path = join(config['save_to'],
                         '{}.cpkt'.format(config['model_name']))
        saver = tf.train.Saver(var_list=tf.trainable_variables(),
                               max_to_keep=config['n_best'] + 1)
        # TF some version >= 0.11, no longer .cpkt so check for meta file instead
        if exists(cpkt_path + '.meta') and config['reload']:
            self.logger.info('Reload model from {}'.format(cpkt_path))
            saver.restore(sess, cpkt_path)
        self.cpkt_path = cpkt_path
        self.saver = saver

    def train(self):
        """Main training loop: build graphs, run epochs, then evaluate test."""
        tf.reset_default_graph()
        with tf.Graph().as_default():
            tf.set_random_seed(ac.SEED)
            # Training graph first so the validation graph can reuse variables.
            self.train_m = self.get_model(ac.TRAINING)
            self.dev_m = self.get_model(ac.VALIDATING)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                self.reload_and_get_cpkt_saver(self.config, sess)
                self.logger.info('Set learning rate to {}'.format(self.lr))
                sess.run(tf.assign(self.train_m.lr, self.lr))
                # NOTE(review): xrange is Python-2-only.
                for e in xrange(self.max_epochs):
                    for b, batch_data in self.data_manager.get_batch(
                            mode=ac.TRAINING, num_batches=self.num_preload):
                        self.run_log_save(sess, b, e, batch_data)
                        self.maybe_validate(sess)
                    self.report_epoch(e)
                self.logger.info('It is finally done, mate!')
                self.logger.info('Train perplexities:')
                self.logger.info(', '.join(map(str, self.train_perps)))
                numpy.save(join(self.config['save_to'], 'train_perps.npy'),
                           self.train_perps)
                self.logger.info('Save final checkpoint')
                self.saver.save(sess, self.cpkt_path)
                # Evaluate on test
                if exists(self.data_manager.ids_files[ac.TESTING]):
                    self.logger.info('Evaluate on test')
                    # Restore the checkpoint with the best dev BLEU first.
                    best_bleu = numpy.max(self.validator.best_bleus)
                    best_cpkt_path = self.validator.get_cpkt_path(best_bleu)
                    self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
                    self.saver.restore(sess, best_cpkt_path)
                    self.validator.evaluate(sess, self.dev_m, mode=ac.TESTING)

    def report_epoch(self, e):
        """Log end-of-epoch stats, record train perplexity, reset counters.

        :param e: 0-based epoch index (logged as ``e + 1``).
        """
        self.logger.info('Finish epoch {}'.format(e + 1))
        self.logger.info('    It takes {}'.format(
            ut.format_seconds(self.epoch_time)))
        self.logger.info('    Avergage # words/second {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch {}'.format(
            self.epoch_time / self.epoch_batches_done))
        train_perp = self.epoch_loss / self.epoch_weights
        # Reset per-epoch accumulators for the next epoch.
        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.
        # exp() overflows past ~e^300; report inf instead.
        train_perp = numpy.exp(train_perp) if train_perp < 300 else float(
            'inf')
        self.train_perps.append(train_perp)
        self.logger.info('    train perplexity: {}'.format(train_perp))

    def sample_input(self, batch_data):
        """Log the first sentence of a batch in human-readable form."""
        # TODO: more on sampling
        src_sent = batch_data[0][0]
        src_length = batch_data[1][0]
        trg_sent = batch_data[2][0]
        target = batch_data[3][0]
        weight = batch_data[4][0]
        # Map ids back to tokens via the inverse vocabularies.
        src_sent = map(self.src_ivocab.get, src_sent)
        src_sent = u' '.join(src_sent)
        trg_sent = map(self.trg_ivocab.get, trg_sent)
        trg_sent = u' '.join(trg_sent)
        target = map(self.trg_ivocab.get, target)
        target = u' '.join(target)
        weight = ' '.join(map(str, weight))
        self.logger.info('Sample input data:')
        self.logger.info(u'Src: {}'.format(src_sent))
        self.logger.info(u'Src len: {}'.format(src_length))
        self.logger.info(u'Trg: {}'.format(trg_sent))
        self.logger.info(u'Tar: {}'.format(target))
        self.logger.info(u'W: {}'.format(weight))

    def run_log_save(self, sess, b, e, batch_data):
        """Run one training step, update stats, log and checkpoint when due.

        :param sess: active TF session.
        :param b: batch index within the epoch (for logging only).
        :param e: 0-based epoch index (for logging only).
        :param batch_data: (src_inputs, src_seq_lengths, trg_inputs,
            trg_targets, target_weights) arrays.
        """
        start = time.time()
        src_inputs, src_seq_lengths, trg_inputs, trg_targets, target_weights = batch_data
        feed = {
            self.train_m.src_inputs: src_inputs,
            self.train_m.src_seq_lengths: src_seq_lengths,
            self.train_m.trg_inputs: trg_inputs,
            self.train_m.trg_targets: trg_targets,
            self.train_m.target_weights: target_weights
        }
        # One optimization step; loss fetched for bookkeeping.
        loss, _ = sess.run([self.train_m.loss, self.train_m.train_op], feed)
        num_trg_words = numpy.sum(target_weights)
        self.num_batches_done += 1
        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_weights += num_trg_words
        self.log_train_loss += loss
        self.log_train_weights += num_trg_words
        self.epoch_time += time.time() - start
        # Occasionally dump a readable sample of the input.
        if self.num_batches_done % (10 * self.log_freq) == 0:
            self.sample_input(batch_data)
        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done
            avg_word_perp = self.log_train_loss / self.log_train_weights
            # exp() overflows past ~e^300; report inf instead.
            avg_word_perp = numpy.exp(
                avg_word_perp) if avg_word_perp < 300 else float('inf')
            # Reset the rolling log-window accumulators.
            self.log_train_loss = 0.
            self.log_train_weights = 0.
            self.logger.info('Batch {}, epoch {}/{}:'.format(
                b, e + 1, self.max_epochs))
            self.logger.info(
                '   avg word perp:   {0:.2f}'.format(avg_word_perp))
            self.logger.info('   acc trg words/s: {}'.format(
                int(acc_speed_word)))
            self.logger.info(
                '   acc sec/batch:   {0:.2f}'.format(acc_speed_time))
        if self.num_batches_done % self.save_freq == 0:
            start = time.time()
            self.saver.save(sess, self.cpkt_path)
            self.logger.info('Save model to {}, takes {}'.format(
                self.cpkt_path, ut.format_seconds(time.time() - start)))

    def maybe_validate(self, sess):
        """Run validation (and best-checkpoint saving) when a cycle is due."""
        if self.num_batches_done % self.validate_freq == 0:
            self.validator.validate_and_save(sess, self.dev_m, self.saver)

    def _call_me_maybe(self):
        # Intentionally a no-op (joke method).
        pass  # NO
class Model(nn.Module):
    """Transformer-style encoder/decoder translation model.

    Owns a DataManager (for vocab masks), source/target word embeddings, a
    structure-driven positional-embedding scheme (``config['struct']``), and
    the Encoder/Decoder stacks. ``forward`` computes the training loss;
    ``beam_decode`` performs inference.
    """

    def __init__(self, config, load_from=None):
        """Build the model from ``config``, or restore it from a checkpoint.

        Args:
            config: dict-like configuration (see the configurations module).
            load_from: optional path to a checkpoint written by ``save``.
                When given, vocab state comes from the checkpoint instead of
                being initialized from data files.
        """
        super(Model, self).__init__()
        self.config = config
        self.struct = self.config['struct']
        self.decoder_mask = None  # lazily-grown causal-mask cache, see get_decoder_mask

        self.data_manager = DataManager(config, init_vocab=(not load_from))
        if load_from:
            # Our load_state_dict (below) restores the vocab and, with
            # do_init=True, builds embeddings/submodules before loading weights.
            self.load_state_dict(torch.load(load_from,
                                            map_location=ut.get_device()),
                                 do_init=True)
        else:
            self.init_embeddings()
            self.init_model()
            self.add_struct_params()

        # dict where keys are data_ptrs to dicts of parameter options
        # see https://pytorch.org/docs/stable/optim.html#per-parameter-options
        self.parameter_attrs = {}

    def init_embeddings(self):
        """Create word/positional embeddings, their scales, and the output bias."""
        embed_dim = self.config['embed_dim']
        tie_mode = self.config['tie_mode']
        fix_norm = self.config['fix_norm']
        max_trg_len = self.config['max_trg_length']
        device = ut.get_device()

        # Target positional embedding: fixed sinusoidal, or learned.
        if not self.config['learned_pos_trg']:
            self.pos_embedding_trg = ut.get_position_embedding(
                embed_dim, max_trg_len)
        else:
            self.pos_embedding_trg = Parameter(
                torch.empty(max_trg_len,
                            embed_dim,
                            dtype=torch.float,
                            device=device))
            nn.init.normal_(self.pos_embedding_trg,
                            mean=0,
                            std=embed_dim**-0.5)

        # Word embeddings; vocab sizes are taken from the DataManager's masks.
        self.src_vocab_mask = self.data_manager.vocab_masks[
            self.data_manager.src_lang]
        self.trg_vocab_mask = self.data_manager.vocab_masks[
            self.data_manager.trg_lang]
        src_vocab_size = self.src_vocab_mask.shape[0]
        trg_vocab_size = self.trg_vocab_mask.shape[0]

        self.out_bias = Parameter(
            torch.empty(trg_vocab_size, dtype=torch.float, device=device))
        nn.init.constant_(self.out_bias, 0.)

        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.trg_embedding = nn.Embedding(trg_vocab_size, embed_dim)
        # Output projection is tied to the target embedding matrix.
        self.out_embedding = self.trg_embedding.weight

        # Embedding scales: sqrt(embed_dim), optionally learned per side.
        if self.config['separate_embed_scales']:
            self.src_embed_scale = Parameter(
                torch.tensor([embed_dim**0.5], device=device))
            self.trg_embed_scale = Parameter(
                torch.tensor([embed_dim**0.5], device=device))
        else:
            self.src_embed_scale = self.trg_embed_scale = torch.tensor(
                [embed_dim**0.5], device=device)

        self.src_pos_embed_scale = torch.tensor([(embed_dim / 2)**0.5],
                                                device=device)
        self.trg_pos_embed_scale = torch.tensor(
            [1.], device=device
        )  # trg pos embedding already returns vector of norm sqrt(embed_dim/2)
        if self.config['learn_pos_scale']:
            self.src_pos_embed_scale = Parameter(self.src_pos_embed_scale)
            self.trg_pos_embed_scale = Parameter(self.trg_pos_embed_scale)

        if tie_mode == ac.ALL_TIED:
            self.src_embedding.weight = self.trg_embedding.weight

        if not fix_norm:
            nn.init.normal_(self.src_embedding.weight,
                            mean=0,
                            std=embed_dim**-0.5)
            nn.init.normal_(self.trg_embedding.weight,
                            mean=0,
                            std=embed_dim**-0.5)
        else:
            d = 0.01  # pure magic
            nn.init.uniform_(self.src_embedding.weight, a=-d, b=d)
            nn.init.uniform_(self.trg_embedding.weight, a=-d, b=d)

    def init_model(self):
        """Build the encoder/decoder stacks and initialize their weights."""
        num_enc_layers = self.config['num_enc_layers']
        num_enc_heads = self.config['num_enc_heads']
        num_dec_layers = self.config['num_dec_layers']
        num_dec_heads = self.config['num_dec_heads']
        embed_dim = self.config['embed_dim']
        ff_dim = self.config['ff_dim']
        dropout = self.config['dropout']
        norm_in = self.config['norm_in']

        self.encoder = Encoder(num_enc_layers,
                               num_enc_heads,
                               embed_dim,
                               ff_dim,
                               dropout=dropout,
                               norm_in=norm_in)
        self.decoder = Decoder(num_dec_layers,
                               num_dec_heads,
                               embed_dim,
                               ff_dim,
                               dropout=dropout,
                               norm_in=norm_in)

        # Xavier-init matrices, zero-init vectors (biases); layer-norm
        # parameters are deliberately left at their defaults.
        init_func = nn.init.xavier_normal_ if self.config[
            'weight_init_type'] == ac.XAVIER_NORMAL else nn.init.xavier_uniform_
        for m in [
                self.encoder.self_atts, self.encoder.pos_ffs,
                self.decoder.self_atts, self.decoder.pos_ffs,
                self.decoder.enc_dec_atts
        ]:
            for p in m.parameters():
                if p.dim() > 1:
                    init_func(p)
                else:
                    nn.init.constant_(p, 0.)

    def add_struct_params(self):
        """Register the struct's extra positional parameters on this module.

        Names ending in '__const__' are kept as plain tensors (not learned);
        the rest are registered so the optimizer sees them.
        """
        self.struct_params = self.struct.get_params(self.config) if hasattr(
            self.struct, "get_params") else {}
        if self.config['learned_pos_src']:
            self.struct_params = {
                name: Parameter(x)
                for name, x in self.struct_params.items()
            }
            for name, x in self.struct_params.items():
                if not name.endswith('__const__'):
                    self.register_parameter(name, x)

    def get_decoder_mask(self, size):
        """Return a [1, 1, size, size] upper-triangular (causal) bool mask.

        The mask is cached and only re-allocated when a larger size is
        requested; smaller requests slice the cached mask.
        """
        if self.decoder_mask is None or self.decoder_mask.size()[-1] < size:
            self.decoder_mask = torch.triu(torch.ones(
                (1, 1, size, size), dtype=torch.bool,
                device=ut.get_device()),
                                           diagonal=1)
            return self.decoder_mask
        else:
            return self.decoder_mask[:, :, :size, :size]

    def get_pos_embedding_h(self, x):
        """Positional embedding for one source structure ``x``."""
        embed_dim = self.config['embed_dim']
        pe = x.get_pos_embedding(embed_dim, **self.struct_params)
        # A struct-typed result is flattened into its node embeddings.
        if isinstance(pe, type(x)):
            pe = pe.flatten()
        return pe if torch.is_tensor(pe) else torch.stack(
            pe)  # [bsz, embed_dim]

    def get_pos_embedding(self, max_len, structs=None):
        """Positional embeddings for a batch.

        With ``structs``: per-structure embeddings padded to a common length
        -> [bsz, max_len, embed_dim]. Without: the target-side table
        -> [1, max_len, embed_dim].
        """
        if structs is not None:
            pe = [self.get_pos_embedding_h(x) for x in structs]
            return torch.nn.utils.rnn.pad_sequence(
                pe, batch_first=True)  # [bsz, max_len, embed_dim]
        else:
            return self.pos_embedding_trg[:max_len, :].unsqueeze(
                0)  # [1, max_len, embed_dim]

    def get_input(self, toks, structs=None, calc_reg=False):
        """Embed ``toks`` (+ positional embeddings) for encoder or decoder.

        ``structs is not None`` selects the source side (src embedding and
        struct positional embeddings); ``None`` selects the target side.
        Returns (embeddings, reg_penalty) where reg_penalty is 0.0 unless
        ``calc_reg``.
        """
        max_len = toks.size()[-1]
        embeds = self.src_embedding if structs is not None else self.trg_embedding
        word_embeds = embeds(toks)  # [bsz, max_len, embed_dim]

        embed_scale = self.trg_embed_scale if structs is None else self.src_embed_scale
        if self.config['fix_norm']:
            word_embeds = ut.normalize(word_embeds, scale=False)
        else:
            word_embeds = word_embeds * embed_scale

        pos_embeds = self.get_pos_embedding(max_len, structs)
        pe_scale = self.src_pos_embed_scale if structs is not None else self.trg_pos_embed_scale

        reg_penalty = 0.0
        if calc_reg:
            reg_penalty = self.struct.get_reg_penalty(
                pos_embeds, toks != ac.PAD_ID) * self.config['pos_norm_penalty']

        # Optionally also add plain sinusoidal embeddings on the source side.
        sinusoidal_pe = self.get_pos_embedding(
            max_len) if structs is not None and self.config[
                'add_sinusoidal_pe_src'] else 0
        return word_embeds + sinusoidal_pe + pos_embeds * pe_scale, reg_penalty

    def get_encoder_masks(self, src_toks, src_structs):
        """Return (padding mask, per-layer encoder mask) for the source batch."""
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(
            2)  # [bsz, 1, 1, max_src_len]
        # A struct may supply its own (e.g. head-specific) attention mask.
        if hasattr(self.struct, "get_enc_mask"):
            encoder_mask_down = self.struct.get_enc_mask(
                src_toks, src_structs, self.config['num_enc_heads'],
                **self.struct_params)
        else:
            encoder_mask_down = encoder_mask
        return encoder_mask, encoder_mask_down

    def forward(self, src_toks, src_structs, trg_toks, targets, b=None, e=None):
        """Compute label-smoothed training loss for one batch.

        Returns a dict {'loss': smoothed loss (+ any struct regularization),
        'nll_loss': unsmoothed negative log-likelihood}. ``b``/``e`` (batch,
        epoch) are accepted for API compatibility but unused here.
        """
        encoder_mask, encoder_mask_down = self.get_encoder_masks(
            src_toks, src_structs)
        decoder_mask = self.get_decoder_mask(trg_toks.size()[-1])

        encoder_inputs, reg_penalty = self.get_input(
            src_toks,
            src_structs,
            calc_reg=hasattr(self.struct, "get_reg_penalty"))
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask_down)

        decoder_inputs, _ = self.get_input(trg_toks)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logits = self.logit_fn(decoder_outputs)
        neglprobs = F.log_softmax(logits, -1)
        # Zero out log-probs of tokens outside the target vocab mask.
        neglprobs = neglprobs * self.trg_vocab_mask.reshape(1, -1)
        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID
        # Multiplying by the mask (instead of boolean indexing) is faster here.
        nll_loss = -neglprobs.gather(dim=-1, index=targets)
        nll_loss = nll_loss * non_pad_mask
        smooth_loss = -neglprobs.sum(dim=-1, keepdim=True) * non_pad_mask

        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
        label_smoothing = self.config['label_smoothing']
        if label_smoothing > 0:
            # Uniform label smoothing over the in-vocab tokens.
            loss = (
                1.0 - label_smoothing
            ) * nll_loss + label_smoothing * smooth_loss / self.trg_vocab_mask.sum(
            )
        else:
            loss = nll_loss

        loss += reg_penalty
        return {'loss': loss, 'nll_loss': nll_loss}

    def logit_fn(self, decoder_output):
        """Project decoder states to vocab logits, masking out-of-vocab slots."""
        softmax_weight = self.out_embedding if not self.config[
            'fix_norm'] else ut.normalize(self.out_embedding, scale=True)
        logits = F.linear(decoder_output, softmax_weight, bias=self.out_bias)
        logits = logits.reshape(-1, logits.size()[-1])
        # -3e38 (near float32 min) rather than -1e9 so masked slots can never
        # win the softmax; in-place fill avoids an extra allocation.
        logits.masked_fill_(~self.trg_vocab_mask.unsqueeze(0), -3e38)
        return logits

    def beam_decode(self, src_toks, src_structs):
        """Translate a minibatch of sentences.

        Arguments: src_toks[i,j] is the jth word of sentence i.
        Return: See encoders.Decoder.beam_decode
        """
        encoder_mask, encoder_mask_down = self.get_encoder_masks(
            src_toks, src_structs)
        encoder_inputs, _ = self.get_input(src_toks, src_structs)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask_down)

        # Per-sentence length budget: source length + 50, capped by config.
        max_lengths1 = torch.sum(src_toks != ac.PAD_ID, dim=-1).type(
            src_toks.type()) + 50
        max_lengths2 = torch.tensor(self.config['max_trg_length']).type(
            src_toks.type())
        max_lengths = torch.min(max_lengths1, max_lengths2)

        def get_trg_inp(ids, time_step):
            # Embed the beam's current target ids at decoding step time_step.
            ids = ids.type(src_toks.type())
            word_embeds = self.trg_embedding(ids)
            if self.config['fix_norm']:
                word_embeds = ut.normalize(word_embeds, scale=False)
            else:
                word_embeds = word_embeds * self.trg_embed_scale
            pos_embeds = self.pos_embedding_trg[time_step, :].reshape(1, 1, -1)
            return word_embeds + pos_embeds * self.trg_pos_embed_scale

        def logprob(decoder_output):
            return F.log_softmax(self.logit_fn(decoder_output), dim=-1)

        if self.config['length_model'] == ac.GNMT_LENGTH_MODEL:
            length_model = ut.gnmt_length_model(self.config['length_alpha'])
        elif self.config['length_model'] == ac.LINEAR_LENGTH_MODEL:
            length_model = lambda t, p: p + self.config['length_alpha'] * t
        elif self.config['length_model'] == ac.NO_LENGTH_MODEL:
            length_model = lambda t, p: p
        else:
            # BUG FIX: original read self.config[length_model], using the
            # local variable (unassigned on this branch) as the key, which
            # raised NameError instead of the intended ValueError.
            raise ValueError('invalid length_model ' +
                             str(self.config['length_model']))

        return self.decoder.beam_decode(encoder_outputs,
                                        encoder_mask,
                                        get_trg_inp,
                                        logprob,
                                        length_model,
                                        ac.BOS_ID,
                                        ac.EOS_ID,
                                        max_lengths,
                                        beam_size=self.config['beam_size'])

    def load_state_dict(self, loaded_dict, do_init=False):
        """Restore model weights and vocab from a checkpoint dict.

        NOTE: intentionally shadows nn.Module.load_state_dict with a
        different contract — ``loaded_dict`` is the {'model', 'data_manager'}
        dict written by ``save``. With ``do_init`` the submodules are built
        first (used when constructing a Model directly from a checkpoint).
        """
        state_dict = loaded_dict['model']
        vocabs = loaded_dict['data_manager']
        self.data_manager.load_state_dict(vocabs)
        if do_init:
            self.init_embeddings()
            self.init_model()
            self.add_struct_params()
        super().load_state_dict(state_dict)

    def save(self, fp=None):
        """Write a {'model', 'data_manager'} checkpoint to ``fp``.

        Defaults to ``<save_to>/<model_name>.pth`` from the config.
        """
        fp = fp or os.path.join(self.config['save_to'],
                                self.config['model_name'] + '.pth')
        cpkt = {
            'model': self.state_dict(),
            'data_manager': self.data_manager.state_dict(),
        }
        torch.save(cpkt, fp)

    def translate(self,
                  input_file_or_stream,
                  best_output_stream,
                  beam_output_stream,
                  num_preload=ac.DEFAULT_NUM_PRELOAD,
                  to_ids=False):
        """Delegate batched translation of a file/stream to the DataManager."""
        return self.data_manager.translate(self,
                                           input_file_or_stream,
                                           best_output_stream,
                                           beam_output_stream,
                                           num_preload=num_preload,
                                           to_ids=to_ids)
class Translator(object):
    # Stand-alone decoder: loads a trained PyTorch Model from a checkpoint and
    # translates an input file, writing '<input>.best_trans' (1-best) and
    # '<input>.beam_trans' (full beam) next to it. Translation starts from
    # the constructor.

    def __init__(self, args):
        # args must provide: proto (config name), input_file, model_file.
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file
        if self.input_file is None or self.model_file is None or not os.path.exists(
                self.input_file) or not os.path.exists(self.model_file):
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        self.translate()

    def ids_to_trans(self, trans_ids):
        # Convert a sequence of target token ids to a surface string,
        # stopping at (and excluding) the first EOS. Returns (text, kept ids).
        words = []
        word_ids = []
        for idx, word_idx in enumerate(trans_ids):
            if word_idx == ac.EOS_ID:
                break
            words.append(self.data_manager.trg_ivocab[word_idx])
            word_ids.append(word_idx)

        return u' '.join(words), word_ids

    def get_trans(self, probs, scores, symbols):
        # Rank beam hypotheses by score (descending) and render them.
        # Returns (best translation, its token ids, newline-joined beam dump
        # of 'text score prob' lines).
        sorted_rows = numpy.argsort(scores)[::-1]
        best_trans = None
        best_tran_ids = None
        beam_trans = []
        for i, r in enumerate(sorted_rows):
            trans_ids = symbols[r]
            trans_out, trans_out_ids = self.ids_to_trans(trans_ids)
            beam_trans.append(u'{} {:.2f} {:.2f}'.format(
                trans_out, scores[r], probs[r]))
            if i == 0: # highest prob trans
                best_trans = trans_out
                best_tran_ids = trans_out_ids

        return best_trans, best_tran_ids, u'\n'.join(beam_trans)

    def plot_head_map(self, mma, target_labels, target_ids, source_labels,
                      source_ids, filename):
        """https://github.com/EdinburghNLP/nematus/blob/master/utils/plot_heatmap.py
        Change the font in family param below. If the system font is not used, delete matplotlib font cache
        https://github.com/matplotlib/matplotlib/issues/3590
        """
        # Render an attention matrix `mma` as a heatmap with source words as
        # columns and target words as rows; UNK tokens are colored blue.
        fig, ax = plt.subplots()
        heatmap = ax.pcolor(mma, cmap=plt.cm.Blues)

        # put the major ticks at the middle of each cell
        ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False)
        ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False)

        # without this I get some extra columns rows
        # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column
        ax.set_xlim(0, int(mma.shape[1]))
        ax.set_ylim(0, int(mma.shape[0]))

        # want a more natural, table-like display
        ax.invert_yaxis()
        ax.xaxis.tick_top()

        # source words -> column labels
        ax.set_xticklabels(source_labels, minor=False, family='Source Code Pro')
        for xtick, idx in zip(ax.get_xticklabels(), source_ids):
            if idx == ac.UNK_ID:
                xtick.set_color('b')
        # target words -> row labels
        ax.set_yticklabels(target_labels, minor=False, family='Source Code Pro')
        for ytick, idx in zip(ax.get_yticklabels(), target_ids):
            if idx == ac.UNK_ID:
                ytick.set_color('b')

        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(filename)
        plt.close('all')

    def translate(self):
        # Restore the model, count non-blank input lines, beam-decode the
        # whole file, and write 1-best / beam outputs in original line order.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = Model(self.config).to(device)
        self.logger.info('Restore model from {}'.format(self.model_file))
        # Relies on Model's custom load_state_dict, which expects the
        # {'model', 'data_manager'} checkpoint dict written by Model.save.
        model.load_state_dict(torch.load(self.model_file))
        model.eval()

        best_trans_file = self.input_file + '.best_trans'
        beam_trans_file = self.input_file + '.beam_trans'
        # Truncate output files up front.
        open(best_trans_file, 'w').close()
        open(beam_trans_file, 'w').close()

        num_sents = 0
        with open(self.input_file, 'r') as f:
            for line in f:
                if line.strip():
                    num_sents += 1
        all_best_trans = [''] * num_sents
        all_beam_trans = [''] * num_sents
        with torch.no_grad():
            self.logger.info('Start translating {}'.format(self.input_file))
            start = time.time()
            count = 0
            for (src_toks, original_idxs) in self.data_manager.get_trans_input(
                    self.input_file):
                src_toks_cuda = src_toks.to(device)
                # NOTE(review): Model.beam_decode in this file takes
                # (src_toks, src_structs); this call passes only one argument.
                # Confirm which Model version this Translator pairs with.
                rets = model.beam_decode(src_toks_cuda)

                for i, ret in enumerate(rets):
                    probs = ret['probs'].cpu().detach().numpy().reshape([-1])
                    scores = ret['scores'].cpu().detach().numpy().reshape([-1])
                    symbols = ret['symbols'].cpu().detach().numpy()

                    best_trans, best_trans_ids, beam_trans = self.get_trans(
                        probs, scores, symbols)
                    # Batches may be reordered; original_idxs restores input order.
                    all_best_trans[original_idxs[i]] = best_trans + '\n'
                    all_beam_trans[original_idxs[i]] = beam_trans + '\n\n'

                    count += 1
                    if count % 100 == 0:
                        self.logger.info(
                            ' Translating line {}, average {} seconds/sent'.
                            format(count, (time.time() - start) / count))

        model.train()
        with open(best_trans_file, 'w') as ftrans, open(beam_trans_file,
                                                        'w') as btrans:
            ftrans.write(''.join(all_best_trans))
            btrans.write(''.join(all_beam_trans))

        self.logger.info('Done translating {}, it takes {} minutes'.format(
            self.input_file, float(time.time() - start) / 60.0))
class Translator(object):
    """Stand-alone decoder for the legacy TensorFlow model.

    Restores a checkpoint, beam-decodes an input file line by line, writes
    '<input>.best_trans' (1-best, EOS stripped) and '<input>.beam_trans'
    (full beam), and can optionally plot attention heatmaps. Translation
    starts from the constructor.
    """

    def __init__(self, args):
        """args must provide: proto (config name), input_file, model_file,
        plot_align (bool)."""
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.reverse = self.config['reverse']
        self.unk_repl = self.config['unk_repl']
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file
        self.plot_align = args.plot_align

        # TF checkpoints are multi-file; the '.meta' graph file must exist.
        if self.input_file is None or self.model_file is None or not os.path.exists(
                self.input_file) or not os.path.exists(self.model_file +
                                                       '.meta'):
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        _, self.src_ivocab = self.data_manager.init_vocab(
            self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(
            self.data_manager.trg_lang)
        self.translate()

    def get_model(self, mode):
        """Build a Model for `mode` under the shared variable scope.

        Variables are created on the first (TRAINING) call and reused after.
        """
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d)
        with tf.variable_scope(self.config['model_name'],
                               reuse=reuse,
                               initializer=initializer):
            return Model(self.config, mode)

    def translate(self):
        """Restore the checkpoint and translate self.input_file."""
        with tf.Graph().as_default():
            # The training model is built first purely to create the
            # variables that the validating model then reuses.
            train_model = self.get_model(ac.TRAINING)
            model = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                self.logger.info('Restore model from {}'.format(
                    self.model_file))
                saver = tf.train.Saver(var_list=tf.trainable_variables())
                saver.restore(sess, self.model_file)

                best_trans_file = self.input_file + '.best_trans'
                beam_trans_file = self.input_file + '.beam_trans'
                open(best_trans_file, 'w').close()
                open(beam_trans_file, 'w').close()
                # BUG FIX: the original passed 'utf-8' as the third positional
                # argument of built-in open(), which is `buffering` (an int)
                # and raises TypeError — that is the codecs.open signature.
                ftrans = open(best_trans_file, 'w', encoding='utf-8')
                btrans = open(beam_trans_file, 'w', encoding='utf-8')
                self.logger.info('Start translating {}'.format(
                    self.input_file))
                start = time.time()
                count = 0
                for (src_input, src_seq_len,
                     no_unk_src_toks) in self.data_manager.get_trans_input(
                         self.input_file):
                    feed = {
                        model.src_inputs: src_input,
                        model.src_seq_lengths: src_seq_len
                    }

                    probs, scores, symbols, parents, alignments = sess.run(
                        [
                            model.probs, model.scores, model.symbols,
                            model.parents, model.alignments
                        ],
                        feed_dict=feed)
                    # Beam-search outputs are time-major; transpose to
                    # [beam, time] (alignments to [beam, time, src]).
                    alignments = numpy.transpose(alignments, axes=(1, 0, 2))
                    probs = numpy.transpose(numpy.array(probs))
                    scores = numpy.transpose(numpy.array(scores))
                    symbols = numpy.transpose(numpy.array(symbols))
                    parents = numpy.transpose(numpy.array(parents))

                    best_trans, best_trans_ids, beam_trans, best_trans_alignments = ut.get_trans(
                        probs,
                        scores,
                        symbols,
                        parents,
                        alignments,
                        no_unk_src_toks,
                        self.trg_ivocab,
                        reverse=self.reverse,
                        unk_repl=self.unk_repl)
                    # Drop the trailing EOS token from the 1-best output.
                    best_trans_wo_eos = best_trans.split()[:-1]
                    best_trans_wo_eos = u' '.join(best_trans_wo_eos)
                    ftrans.write(best_trans_wo_eos + '\n')
                    btrans.write(beam_trans + '\n\n')

                    if self.plot_align:
                        src_input = numpy.reshape(src_input, [-1])
                        if self.reverse:
                            src_input = src_input[::-1]
                            no_unk_src_toks = no_unk_src_toks[::-1]
                        trans_toks = best_trans.split()
                        best_trans_alignments = numpy.array(
                            best_trans_alignments)[:len(trans_toks)]
                        filename = '{}_{}.png'.format(self.input_file, count)
                        # BUG FIX: original referenced undefined name
                        # 'best_tran_ids' (typo for best_trans_ids), raising
                        # NameError whenever plot_align was enabled.
                        ut.plot_head_map(best_trans_alignments, trans_toks,
                                         best_trans_ids, no_unk_src_toks,
                                         src_input, filename)

                    count += 1
                    if count % 100 == 0:
                        self.logger.info(
                            ' Translating line {}, average {} seconds/sent'.
                            format(count, (time.time() - start) / count))

                ftrans.close()
                btrans.close()
                self.logger.info(
                    'Done translating {}, it takes {} minutes'.format(
                        self.input_file,
                        float(time.time() - start) / 60.0))