def score_model(source_file, target_file, scorer_settings, options):
    scores = []
    for option in options:
        g = tf.Graph()
        with g.as_default():
            tf_config = tf.ConfigProto()
            tf_config.allow_soft_placement = True
            with tf.Session(config=tf_config) as sess:
                logging.info('Building model...')
                model = StandardModel(option)
                saver = nmt.init_or_restore_variables(option, sess)
                text_iterator = TextIterator(
                    source=source_file.name,
                    target=target_file.name,
                    source_dicts=option.source_dicts,
                    target_dict=option.target_dict,
                    batch_size=scorer_settings.b,
                    maxlen=float('inf'),
                    source_vocab_sizes=option.source_vocab_sizes,
                    target_vocab_size=option.target_vocab_size,
                    use_factor=(option.factors > 1),
                    sort_by_length=False)
                losses = nmt.calc_loss_per_sentence(
                    option, sess, text_iterator, model,
                    normalization_alpha=scorer_settings.normalization_alpha)
                scores.append(losses)
    return scores

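# Usage sketch (hypothetical: `settings` and `configs` stand in for the scorer
# settings and per-model config namespaces that the caller builds elsewhere;
# the file paths are placeholders):
#
#     with open('newstest.src') as src, open('newstest.ref') as ref:
#         scores = score_model(src, ref, settings, configs)
#     # scores[i][j] is the per-sentence loss of sentence j under model i,
#     # as returned by nmt.calc_loss_per_sentence.
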
def theano_to_tensorflow_model(in_path, out_path):
    saved_model = np.load(in_path)
    config = theano_to_tensorflow_config(in_path)
    th2tf = construct_parameter_map(config)
    with tf.Session() as sess:
        logging.info('Building model...')
        model = StandardModel(config)
        saver = nmt.init_or_restore_variables(config, sess)
        seen = set()
        assign_ops = []
        for key in saved_model.keys():
            # ignore adam parameters
            if key.startswith('adam'):
                continue
            tf_name = th2tf[key]
            if tf_name is not None:
                assert tf_name not in seen
                seen.add(tf_name)
                tf_var = tf.get_default_graph().get_tensor_by_name(tf_name)
                if (sess.run(tf.shape(tf_var)) != saved_model[key].shape).any():
                    print "mismatch for", tf_name, key, saved_model[key].shape, sess.run(tf.shape(tf_var))
                assign_ops.append(tf.assign(tf_var, saved_model[key]))
            else:
                print "Not saving", key, "because no TF equivalent"
        sess.run(assign_ops)
        saver.save(sess, save_path=out_path)
        print "The following TF variables were not assigned (excluding Adam vars):"
        print "You should see only the 'beta1_power', 'beta2_power' and 'time' variables listed"
        for tf_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
            if tf_var.name not in seen and 'Adam' not in tf_var.name:
                print tf_var.name

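# Hypothetical conversion call (paths are placeholders): the input is assumed
# to be a Theano/numpy .npz parameter archive, and the output path is the
# prefix under which the TensorFlow checkpoint is written.
#
#     theano_to_tensorflow_model('model.npz', 'model-tf/model')
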
def _load_models(self, process_id, sess):
    """
    Loads models and returns them
    """
    logging.debug("Process '%s' - Loading models\n" % (process_id))

    import tensorflow as tf
    models = []
    for i, options in enumerate(self._options):
        with tf.variable_scope("model%d" % i) as scope:
            model = StandardModel(options)
            saver = init_or_restore_variables(options, sess,
                                              ensemble_scope=scope)
            models.append(model)

    logging.info("NOTE: Length of translations is capped to {}".format(
        self._options[0].translation_maxlen))
    return models

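# Descriptive note (based on the scoping above, assuming the enclosing object
# holds one config namespace per model in self._options): each model's
# variables live under their own scope ("model0", "model1", ...), so several
# checkpoints can be restored into the same session for ensemble decoding.
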
def validate_helper(config, sess):
    logging.info('Building model...')
    model = StandardModel(config)
    saver = init_or_restore_variables(config, sess)
    valid_text_iterator = TextIterator(
        source=config.valid_source_dataset,
        target=config.valid_target_dataset,
        source_dicts=config.source_dicts,
        target_dict=config.target_dict,
        batch_size=config.valid_batch_size,
        maxlen=config.maxlen,
        source_vocab_sizes=config.source_vocab_sizes,
        target_vocab_size=config.target_vocab_size,
        shuffle_each_epoch=False,
        sort_by_length=False,  # TODO
        use_factor=(config.factors > 1),
        maxibatch_size=config.maxibatch_size)
    costs = validate(config, sess, valid_text_iterator, model)
    lines = open(config.valid_target_dataset).readlines()
    for cost, line in zip(costs, lines):
        logging.info("{0} {1}".format(cost, line.strip()))

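# Hypothetical invocation (a sketch; `config` is assumed to be the parsed
# command-line namespace used elsewhere in this file):
#
#     with tf.Session() as sess:
#         validate_helper(config, sess)
#
# This logs one per-sentence cost for each line of config.valid_target_dataset.
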
def train(config, sess):
    assert ((config.prior_model is not None and
             tf.train.checkpoint_exists(os.path.abspath(config.prior_model))) or
            config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    logging.info('Building model...')
    model = StandardModel(config)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=config.learning_rate)
    else:
        logging.error('No valid optimizer defined: {}'.format(config.optimizer))
        sys.exit(1)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [], initializer=init, trainable=False)

    if config.summaryFreq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, model, optimizer, global_step, writer)

    saver, progress = init_or_restore_variables(config, sess, train=True)
    global_step.load(progress.uidx, sess)

    # save model options
    config_as_dict = OrderedDict(sorted(vars(config).items()))
    json.dump(config_as_dict, open('%s.json' % config.saveto, 'wb'), indent=2)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    for progress.eidx in xrange(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for source_sents, target_sents in text_iterator:
            print("")
            print("")
            print("")
            print("########## Source Sents ############")
            print(source_sents)
            print("")
            print("")
            print("")
            print("########## Target Sents ############")
            print(target_sents)
            if len(source_sents[0][0]) != config.factors:
                logging.error(
                    'Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'
                    .format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info('Minibatch with zero samples under length {0}'.format(
                    config.maxlen))
                continue
            write_summary_for_this_batch = config.summaryFreq and (
                (progress.uidx % config.summaryFreq == 0) or
                (config.finish_after and progress.uidx % config.finish_after == 0))
            (factors, seqLen, batch_size) = x_in.shape

            loss = updater.update(sess, x_in, x_mask_in, y_in, y_mask_in,
                                  write_summary_for_this_batch)
            total_loss += loss
            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            if config.dispFreq and progress.uidx % config.dispFreq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info(
                    '{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'
                    .format(disp_time, progress.eidx, progress.uidx,
                            total_loss / n_words, n_words / duration,
                            n_sents / duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sampleFreq and progress.uidx % config.sampleFreq == 0:
                x_small, x_mask_small, y_small = \
                    x_in[:, :, :10], x_mask_in[:, :10], y_in[:, :10]
                samples = model.sample(sess, x_small, x_mask_small)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beamFreq and progress.uidx % config.beamFreq == 0:
                x_small, x_mask_small, y_small = \
                    x_in[:, :, :10], x_mask_in[:, :10], y_in[:, :10]
                samples = model.beam_search(sess, x_small, x_mask_small,
                                            config.beam_size)
                # samples is a list with shape batch x beam x len
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, len(sample), cost / len(sample))
                        logging.info(msg)

            if config.validFreq and progress.uidx % config.validFreq == 0:
                costs = validate(config, sess, valid_text_iterator, model)
                # validation loss is mean of normalized sentence log probs
                valid_loss = sum(costs) / len(costs)
                if (len(progress.history_errs) == 0 or
                        valid_loss < min(progress.history_errs)):
                    progress.history_errs.append(valid_loss)
                    progress.bad_counter = 0
                    saver.save(sess, save_path=config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_loss)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    score = validate_with_script(sess, model, config,
                                                 valid_text_iterator)
                    need_to_save = (
                        score is not None and
                        (len(progress.valid_script_scores) == 0 or
                         score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        save_path = config.saveto + ".best-valid-script"
                        saver.save(sess, save_path=save_path)
                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.saveFreq and progress.uidx % config.saveFreq == 0:
                saver.save(sess, save_path=config.saveto,
                           global_step=progress.uidx)
                progress_path = '{0}-{1}.progress.json'.format(config.saveto,
                                                               progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess, save_path=config.saveto,
                           global_step=progress.uidx)
                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(config.saveto,
                                                               progress.uidx)
                progress.save_to_json(progress_path)
                break

        if progress.estop:
            break

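# Notes on the training loop above (descriptive, based on how the tensors and
# counters are used):
# - x_in from util.prepare_data has shape (factors, seq_len, batch), so
#   x_in[:, :, :10] keeps the first 10 sentences of the minibatch for
#   sampling and beam search.
# - Early stopping: progress.bad_counter counts consecutive validations
#   without a new best mean per-sentence cost and sets progress.estop once it
#   exceeds config.patience.
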
    config.source_dicts = config.dictionaries[:-1]
    config.source_vocab_sizes = vocab_sizes[:-1]
    config.target_dict = config.dictionaries[-1]
    config.target_vocab_size = vocab_sizes[-1]

    # set the model version
    config.model_version = 0.2
    config.theano_compat = False

    return config


if __name__ == "__main__":
    # set up logging
    level = logging.INFO
    logging.basicConfig(level=level, format='%(levelname)s: %(message)s')

    config = parse_args()
    logging.info(config)
    with tf.Session() as sess:
        if config.translate_valid:
            logging.info('Building model...')
            model = StandardModel(config)
            saver = init_or_restore_variables(config, sess)
            translate(sess, model, config)
        elif config.run_validation:
            validate_helper(config, sess)
        else:
            train(config, sess)