def train_RNNLM(_run):
    """Train the RNN language model for one sacred run.

    Builds the graph, trains with early stopping on validation perplexity,
    checkpoints the best-so-far weights, and finally evaluates test
    perplexity with the *best* checkpoint restored.

    Args:
        _run: sacred Run object; `_run.config` holds the experiment config
            and `_run._id` names the per-run dump directory.
    """
    # maintain consistency between sacred config and experiment config
    config = ExperimentConfig(**_run.config)

    experiment_dump_path = os.path.join(experiment_path, str(_run._id), "tf_dump")
    if not os.path.exists(experiment_dump_path):
        os.makedirs(experiment_dump_path)

    if config.debug:
        print("Running in debug mode...")

    with tf.Graph().as_default():
        # set random seed before the graph is built
        tf.set_random_seed(config.tf_random_seed)
        model = RNNLM_Model(config, load_corpus=True)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        # allow_growth avoids grabbing all GPU memory up front
        tf_config = tf.ConfigProto()
        tf_config.gpu_options.allow_growth = True
        with tf.Session(config=tf_config) as session:
            best_val_pp = float('inf')
            best_val_epoch = 0
            # keep a handle so the event file is flushed/closed properly
            summary_writer = tf.summary.FileWriter(
                os.path.join(train_path, "TensorLog"), session.graph)
            session.run(init)

            checkpoint_path = os.path.join(experiment_dump_path, 'rnnlm.weights')
            for epoch in range(config.max_epochs):
                print('Epoch {}'.format(epoch))
                start = time.time()
                train_pp = model.run_epoch(
                    session, model.encoded_train, train_op=model.train_step)
                print('Training perplexity: {}'.format(train_pp))
                print('Total Training time: {}'.format(time.time() - start))

                valid_pp = model.run_epoch(session, model.encoded_dev)
                print('Validation perplexity: {}'.format(valid_pp))
                if valid_pp < best_val_pp:
                    # new best model: checkpoint it
                    best_val_pp = valid_pp
                    best_val_epoch = epoch
                    saver.save(session, checkpoint_path)
                if epoch - best_val_epoch > config.early_stopping:
                    # no validation improvement for `early_stopping` epochs
                    break
                print('Total time: {}'.format(time.time() - start))

            # BUGFIX: evaluate the test set with the best validation
            # checkpoint, not whatever weights the last epoch left behind.
            saver.restore(session, checkpoint_path)
            test_pp = model.run_epoch(session, model.encoded_test)
            print('Test perplexity: {}'.format(test_pp))
            summary_writer.close()
def dump_trained_weights(experiment, verbose):
    """Restore a trained RNNLM checkpoint and dump its weights as numpy arrays.

    Args:
        experiment: experiment id; locates the checkpoint under
            `<experiment_path>/<experiment>/tf_dump/` and the output under
            `<experiment_path>/<experiment>/weights/`.
        verbose: forwarded to `dump_weights`.
    """
    config = get_configs(experiment)

    # Still need to load the model to build graph
    # Graph is not saved
    RNNLM_Model(config)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as session:
        session.run(init)
        saver.restore(
            session,
            os.path.join(experiment_path, str(experiment), "tf_dump",
                         'rnnlm.weights'))

        # LSTM gate weights (hidden->gate, input->gate), embedding, biases.
        # BUGFIX: removed duplicated 'HMo' entry that dumped the same
        # variable twice.
        dump_vars = [
            'HMi', 'HMf', 'HMo', 'HMg',
            'IMi', 'IMf', 'IMo', 'IMg',
            'LM', 'bi', 'bf', 'bo', 'bg', 'b2'
        ]
        if config.share_embedding:
            dump_vars += ['PM']
        else:
            dump_vars += ['UM']
        if config.V_table:
            # per-segment projection tables replace the monolithic 'LM'
            dump_vars.remove('LM')
            for i, seg in enumerate(config.embedding_seg):
                if i != 0:
                    dump_vars += ['VT{}'.format(i)]
                dump_vars += ['LM{}'.format(i)]

        weight_dict = tf_weights_to_np_weights_dict(session, dump_vars)

        if config.D_softmax:
            # instead save the full patched embedding, split each block
            # in "LM" into list of matrices
            blocks = []
            col_s = 0
            for size, s, e in config.embedding_seg:
                if e is None:
                    e = weight_dict['LM'].shape[0]
                blocks.append(weight_dict['LM'][s:e, col_s:col_s + size])
                col_s += size
            weight_dict['LM'] = blocks

        weight_dump_dir = os.path.join(experiment_path, str(experiment),
                                       "weights")
        dump_weights(weight_dict, weight_dump_dir, verbose)
def auto_generate_sentence(experiment=1):
    """Interactively sample sentences from a trained RNNLM checkpoint.

    Restores the weights of the given experiment, then repeatedly prompts
    the user for a seed string and prints a generated sentence until an
    empty line is entered.

    Args:
        experiment: experiment id whose checkpoint is loaded (default 1).
    """
    gen_config = get_configs(experiment)
    # generation feeds one token at a time through a batch of one
    gen_config.batch_size = gen_config.num_steps = 1
    gen_model = RNNLM_Model(gen_config)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init)
        checkpoint = os.path.join(experiment_path, str(experiment),
                                  "tf_dump", 'rnnlm.weights')
        saver.restore(sess, checkpoint)

        seed = '<eos>'
        while seed:
            tokens = generate_sentence(
                sess, gen_model, gen_config, starting_text=seed, temp=1.0)
            # tokens look like "word/tag"; show only the word part
            print(' '.join(tok.split('/')[0] for tok in tokens))
            seed = input('> ')
def create_RNNLM(config):
    """Factory: build an RNNLM_Model with its training corpus loaded."""
    model = RNNLM_Model(config, load_corpus=True)
    return model