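# -----------------------------------------------------------------------------
# Assumed header for this file (a sketch, not verbatim from the repo): the
# functions below rely on standard imports plus project-local helpers and
# config constants. The project-local names listed here are assumptions
# inferred from how they are used below.
# -----------------------------------------------------------------------------
import time

import numpy as np
import tensorflow as tf

import model_and_data_serialization

# Shorthand for the checkpoint lookup used below, called as
# latest_checkpoint(checkpoint_dir, latest_filename).
latest_checkpoint = tf.train.latest_checkpoint

# Assumed to come from the project's own modules / config (names as used
# below): define_objective, define_summaries, get_optimization_ops,
# inf_train_gen, generate_argmax_samples_and_gt_samples, log_samples,
# decode_indices_to_string, restore_param_from_config, get_models_list,
# restore_config, FLAGS, and the constants DATA_DIR, BATCH_SIZE,
# CRITIC_ITERS, GEN_ITERS, RNN_CELL.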
def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length):
    if len(DATA_DIR) == 0:
        raise Exception('Please specify path to data directory in single_length_train.py!')

    lines, _, _ = model_and_data_serialization.load_dataset(
        seq_length=seq_length,
        b_charmap=False,
        b_inv_charmap=False,
        n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])
    global_step = tf.Variable(0, trainable=False)

    disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, \
        inference_op, other_ops = define_objective(charmap, real_inputs_discrete, seq_length,
                                                   gan_type=FLAGS.GAN_TYPE, rnn_cell=RNN_CELL)

    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(
        disc_cost, gen_cost, global_step, FLAGS.DISC_LR, FLAGS.GEN_LR)

    saver = tf.train.Saver(tf.trainable_variables())

    # Use JIT XLA compilation to speed up calculations.
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    with tf.Session(config=config) as session:
        session.run(tf.global_variables_initializer())

        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(prev_seq_length)
            model_and_data_serialization.optimistic_restore(
                session, latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
            # Global param: always load from the current session after finishing the first seq.
            restore_config.set_restore_dir(load_from_curr_session=True)

        gen = inf_train_gen(lines, charmap)

        _gen_cost_list = []
        _disc_cost_list = []
        _step_time_list = []

        for iteration in range(iterations):
            start_time = time.time()

            # Train the critic.
            for i in range(CRITIC_ITERS):
                _data = next(gen)
                if FLAGS.GAN_TYPE.lower() == "fgan":
                    _disc_cost, _, real_scores, _ = session.run(
                        [disc_cost, disc_train_op, disc_real, other_ops["alpha_optimizer_op"]],
                        feed_dict={real_inputs_discrete: _data})
                elif FLAGS.GAN_TYPE.lower() == "wgan":
                    _disc_cost, _, real_scores = session.run(
                        [disc_cost, disc_train_op, disc_real],
                        feed_dict={real_inputs_discrete: _data})
                else:
                    raise ValueError("Appropriate gan type not selected: {}".format(FLAGS.GAN_TYPE))
                _disc_cost_list.append(_disc_cost)

            # Train the generator.
            # In Fisher GAN, the paper measures convergence by gen_cost instead of
            # disc_cost: gen_cost should start at a positive number and decrease
            # towards zero. The lower, the better.
            for i in range(GEN_ITERS):
                _data = next(gen)
                _gen_cost, _ = session.run([gen_cost, gen_train_op],
                                           feed_dict={real_inputs_discrete: _data})
                _gen_cost_list.append(_gen_cost)

            _step_time_list.append(time.time() - start_time)

            if FLAGS.PRINT_EVERY_STEP:
                print("iteration %s/%s" % (iteration, iterations))
                print("disc cost {}".format(_disc_cost))
                print("gen cost {}".format(_gen_cost))
                print("total step time {}".format(time.time() - start_time))

            # Summaries
            if iteration % FLAGS.PRINT_ITERATION == FLAGS.PRINT_ITERATION - 1:
                _data = next(gen)
                summary_str = session.run(merged, feed_dict={real_inputs_discrete: _data})

                tf.logging.warn("iteration %s/%s" % (iteration, iterations))
                tf.logging.warn("disc cost {} gen cost {} average step time {}".format(
                    np.mean(_disc_cost_list), np.mean(_gen_cost_list), np.mean(_step_time_list)))
                _gen_cost_list, _disc_cost_list, _step_time_list = [], [], []

                train_writer.add_summary(summary_str, global_step=iteration)

                fake_samples, samples_real_probabilities, fake_scores = \
                    generate_argmax_samples_and_gt_samples(session, inv_charmap, fake_inputs,
                                                           disc_fake, gen, real_inputs_discrete,
                                                           feed_gt=True)
                log_samples(fake_samples, fake_scores, iteration, seq_length, "train")
                log_samples(decode_indices_to_string(_data, inv_charmap), real_scores,
                            iteration, seq_length, "gt")

                test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
                    session, inv_charmap, inference_op, disc_on_inference, gen,
                    real_inputs_discrete, feed_gt=False)
                log_samples(test_samples, fake_scores, iteration, seq_length, "test")

            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY - 1:
                saver.save(session,
                           model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

        # Save a final checkpoint for this sequence length before moving on.
        saver.save(session,
                   model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")
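# Hypothetical driver for run() above (illustrative only; START_SEQ, END_SEQ
# and ITERATIONS_PER_SEQ_LENGTH are assumed names, not defined in this file):
# train on progressively longer sequences, warm-starting each stage from the
# previous sequence length's checkpoint, which is what the is_first /
# prev_seq_length arguments are for.
def curriculum_sketch():
    _, charmap, inv_charmap = model_and_data_serialization.load_dataset(
        seq_length=32, b_lines=False)
    prev_seq_length = 0
    for seq_length in range(START_SEQ, END_SEQ + 1):
        run(ITERATIONS_PER_SEQ_LENGTH, seq_length,
            is_first=(seq_length == START_SEQ),
            charmap=charmap, inv_charmap=inv_charmap,
            prev_seq_length=prev_seq_length)
        prev_seq_length = seq_length
        tf.reset_default_graph()  # rebuild the graph for the next length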
def evaluate(EVAL_FLAGS):
    # Load dict & params.
    _, charmap, inv_charmap = model_and_data_serialization.load_dataset(
        seq_length=32, b_lines=False)
    seq_length = EVAL_FLAGS.seq_len
    N = EVAL_FLAGS.num_samples
    ckp_list, config_list = get_models_list()
    print("ALL MODELS: %0s" % ckp_list)

    for ckp_path, config_path in zip(ckp_list, config_list):
        model_name = "%0s_%0s" % (ckp_path.split('/')[2], ckp_path.split('/')[4])

        # Filter by model name.
        if not model_name.startswith(EVAL_FLAGS.prefix_filter):
            print("NOT EVALUATING [%0s]" % model_name)
            continue

        tf.reset_default_graph()
        print("EVALUATING [%0s]" % model_name)
        print("restoring config:")
        FLAGS.DISC_STATE_SIZE = restore_param_from_config(config_path, 'DISC_STATE_SIZE')
        FLAGS.GEN_STATE_SIZE = restore_param_from_config(config_path, 'GEN_STATE_SIZE')
        print("DISC_STATE_SIZE [%0d]" % FLAGS.DISC_STATE_SIZE)
        print("GEN_STATE_SIZE [%0d]" % FLAGS.GEN_STATE_SIZE)

        lines, _, _ = model_and_data_serialization.load_dataset(
            seq_length=seq_length, b_charmap=False, b_inv_charmap=False,
            n_examples=FLAGS.MAX_N_EXAMPLES, dataset='heldout')

        real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])
        global_step = tf.Variable(0, trainable=False)

        disc_cost, gen_cost, train_pred, train_pred_for_eval, disc_fake, disc_real, \
            disc_on_inference, inference_op = define_objective(charmap, real_inputs_discrete,
                                                               seq_length)

        # Buffer for the N per-noise-vector predictions of train_pred.
        train_pred_all = np.zeros([N, BATCH_SIZE, seq_length, train_pred.shape[2]],
                                  dtype=np.float32)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        # Load checkpoints.
        internal_checkpoint_dir = ckp_path
        model_and_data_serialization.optimistic_restore(
            sess, latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
        # Global param: always load from the current session after finishing the first seq.
        restore_config.set_restore_dir(load_from_curr_session=True)

        # Create samples.
        print('start creating samples..')
        test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
            sess, inv_charmap, inference_op, disc_on_inference, None,
            real_inputs_discrete, feed_gt=False)
        log_samples(test_samples, fake_scores, 666, seq_length, "test")

        print('starting BPC eval...')
        BPC_list = []
        if EVAL_FLAGS.short_run:
            end_line = BATCH_SIZE * 100  # 100 batches
        else:
            end_line = len(lines) - BATCH_SIZE + 1  # all data

        for start_line in range(0, end_line, BATCH_SIZE):
            t0 = time.time()
            _data = np.array([[charmap[c] for c in l]
                              for l in lines[start_line:start_line + BATCH_SIZE]])

            # Draw N noise vectors and, for each one, compute train_pred.
            for i in range(N):
                train_pred_i = sess.run(train_pred_for_eval,
                                        feed_dict={real_inputs_discrete: _data})
                train_pred_all[i, :, :, :] = train_pred_i

            # Average the N predictions (taken over the first dimension).
            train_pred_average = np.mean(train_pred_all, axis=0)

            # Compute BPC (bits per character).
            train_pred_average_2d = train_pred_average.reshape(
                [train_pred_average.shape[0] * train_pred_average.shape[1],
                 train_pred_average.shape[2]])
            real_data = _data.reshape([_data.shape[0] * _data.shape[1]])

            BPC = 0
            epsilon = 1e-20
            for i in range(real_data.shape[0]):
                BPC -= np.log2(train_pred_average_2d[i, real_data[i]] + epsilon)
            BPC /= real_data.shape[0]
            print("BPC of start_line %d/%d = %.2f" % (start_line, len(lines), BPC))
            print("t_iter = %.2f" % (time.time() - t0))
            BPC_list.append(BPC)
            np.save('BPC_list_temp.npy', BPC_list)

        BPC_final = np.mean(BPC_list)
        print("[%0s]BPC_final = %.2f\n" % (ckp_path, BPC_final))
        np.save("%0s_BPC_list.npy" % model_name, BPC_list)
        np.save("%0s_BPC_final.npy" % model_name, BPC_final)
        sess.close()  # release the session before building the next model's graph
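# The per-batch score in evaluate() is the standard bits-per-character
# estimate, BPC = -(1/T) * sum_t log2(p_hat(c_t)), where p_hat is the model's
# probability of the ground-truth character averaged over N noise draws.
# A vectorized, self-contained equivalent of the elementwise loop above
# (a sketch; bits_per_character is a hypothetical helper, not part of this file):
def bits_per_character(pred_probs, targets, epsilon=1e-20):
    """pred_probs: float array [T, vocab] of predicted char distributions.
    targets: int array [T] of ground-truth char indices."""
    p_true = pred_probs[np.arange(targets.shape[0]), targets]
    return -np.mean(np.log2(p_true + epsilon))

# Sanity check: a uniform model over a 256-char vocabulary scores
# log2(256) = 8 bits per character.
#   uniform = np.full([10, 256], 1.0 / 256)
#   targets = np.random.randint(0, 256, size=10)
#   assert abs(bits_per_character(uniform, targets) - 8.0) < 1e-6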