def disc_pre_train(text_data):
    """Pre-train the hierarchical discriminator before adversarial training.

    Builds a discriminator training set from the generator's data pipeline,
    then hands it to the hierarchical-discriminator trainer together with
    the corpus vocabulary size.
    """
    # NOTE(review): bucket_id=-1 / train_set=None presumably means "sample
    # across all buckets from scratch" — confirm against create_disc_train_set.
    disc_train_set = gens.create_disc_train_set(
        gen_config, text_data, -1, None, gen_config.disc_data_batch_num)
    vocab_size = text_data.getVocabularySize()
    h_disc.hier_train(disc_config, evl_config, vocab_size, disc_train_set)
def al_train(text_data):
    """Adversarial (GAN-style) training loop: alternately update the
    discriminator and the generator, logging losses to TensorBoard and
    checkpointing both models periodically.

    Args:
        text_data: corpus object exposing getVocabularySize(); also consumed
            by the gens.* data helpers.

    Runs forever (``while True``); stop externally.
    """
    with tf.Session() as sess:
        # Bucketed QA training set for the generator.
        train_set = gens.create_train_set(gen_config, text_data)
        total_qa_size = 0
        # NOTE(review): loop variable `set` shadows the builtin — left as-is.
        for i, set in enumerate(train_set):
            length = len(set)
            print("Generator train_set_{} len: {}".format(i, length))
            total_qa_size += length
        print("Generator train_set total size is {} QA".format(total_qa_size))

        # Cumulative bucket-size distribution, used below to sample a bucket
        # proportionally to its size.
        train_bucket_sizes = [
            len(train_set[b]) for b in range(len(gen_config.buckets))
        ]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in range(len(train_bucket_sizes))
        ]

        vocab_size = text_data.getVocabularySize()
        disc_model = h_disc.create_model(sess, disc_config, vocab_size,
                                         disc_config.name_model)
        gen_model = gens.create_model(sess, gen_config, vocab_size,
                                      forward_only=False,
                                      name_scope=gen_config.name_model)

        current_step = 0
        # Running averages accumulated between checkpoints; reset at each
        # steps_per_checkpoint boundary below.
        step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
        # NOTE(review): these Summary protos are never cleared, so value.add()
        # below appends new values at every checkpoint — the protos grow
        # without bound over a long run; confirm this is intended.
        gen_loss_summary = tf.Summary()
        disc_loss_summary = tf.Summary()
        gen_writer = tf.summary.FileWriter(gen_config.tensorboard_dir,
                                           sess.graph)
        disc_writer = tf.summary.FileWriter(disc_config.tensorboard_dir,
                                            sess.graph)

        while True:
            current_step += 1
            # Sample a bucket id with probability proportional to bucket size
            # (inverse-CDF sampling over train_buckets_scale).
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in range(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])
            start_time = time.time()

            print(
                "==================Update Discriminator: %d=================="
                % current_step)
            # --- Discriminator phase: D_STEPS updates per outer step. ---
            for i in range(D_STEPS):
                print(
                    "=============It's the %d time update Discriminator in current step============="
                    % (i + 1))
                # 1. Sample (X,Y) from real data and sample ^Y from G(*|X)
                query_set, answer_set, gen_set = gens.create_disc_train_set(
                    gen_config, text_data, bucket_id, train_set, 1, sess,
                    gen_model)
                b_query, b_answer, b_gen = query_set[bucket_id], answer_set[
                    bucket_id], gen_set[bucket_id]
                train_query, train_answer, train_labels = h_disc.hier_get_batch(
                    disc_config, len(b_query) - 1, b_query, b_answer, b_gen)
                # Transpose to time-major layout expected by disc_step
                # (presumably — confirm against disc_step's contract).
                train_query = np.transpose(train_query)
                train_answer = np.transpose(train_answer)
                _, disc_step_loss = disc_step(sess, bucket_id, disc_model,
                                              train_query, train_answer,
                                              train_labels,
                                              forward_only=False)
                # Average over both the D_STEPS inner loop and the checkpoint
                # interval.
                disc_loss += disc_step_loss / (
                    D_STEPS * disc_config.steps_per_checkpoint)
                if i == D_STEPS - 1:
                    print("disc_step_loss: ", disc_step_loss)

            print("==================Update Generator: %d=================="
                  % current_step)
            # --- Generator phase: G_STEPS updates per outer step. ---
            for j in range(G_STEPS):
                print(
                    "=============It's the %d time update Generator in current step============="
                    % (j + 1))
                encoder_inputs, decoder_inputs, target_weights,\
                source_inputs, source_outputs = gens.get_batch(
                    gen_config, train_set, bucket_id, gen_config.batch_size,
                    text_data)
                # Negative (generated) responses for the current batch.
                decoder_inputs_negative = get_negative_decoder_inputs(
                    sess, gen_model, encoder_inputs, decoder_inputs,
                    target_weights, bucket_id)
                decoder_inputs_negative = np.transpose(decoder_inputs_negative)

                # Build a labeled batch: real pairs labeled 1, then
                # beam_size rounds of Monte-Carlo-sampled outputs labeled 0.
                train_query, train_answer, train_labels = [], [], []
                for query, answer in zip(source_inputs, source_outputs):
                    train_query.append(query)
                    train_answer.append(answer)
                    train_labels.append(1)
                for _ in range(gen_config.beam_size):
                    gen_set = get_negative_decoder_inputs(sess, gen_model,
                                                          encoder_inputs,
                                                          decoder_inputs,
                                                          target_weights,
                                                          bucket_id,
                                                          mc_search=True)
                    # Pair each sampled output with the query at the same
                    # batch position (the first len(gen_set) entries of
                    # train_query are the real queries).
                    for i, output in enumerate(gen_set):
                        train_query.append(train_query[i])
                        train_answer.append(output)
                        train_labels.append(0)
                train_query = np.transpose(train_query)
                train_answer = np.transpose(train_answer)

                # 2. Discriminator scores the mixed batch (forward only);
                # its output is used as the reward signal for the generator.
                reward, _ = disc_step(sess, bucket_id, disc_model, train_query,
                                      train_answer, train_labels,
                                      forward_only=True)
                batch_reward += reward / gen_config.steps_per_checkpoint
                print("step_reward: ", reward)

                # 3. Policy-gradient-style update: push the generator using
                # the reward on its own (negative) samples.
                gan_adjusted_loss, gen_step_loss, _ = gen_model.step(
                    sess, encoder_inputs, decoder_inputs_negative,
                    target_weights, bucket_id, forward_only=False,
                    reward=reward, up_reward=True, debug=True)
                gen_loss += gen_step_loss / gen_config.steps_per_checkpoint
                print("gen_step_loss: ", gen_step_loss)
                print("gen_step_adjusted_loss: ", gan_adjusted_loss)

                # 4. Teacher-forcing update on the real decoder inputs to keep
                # the generator anchored to the data distribution.
                t_adjusted_loss, t_step_loss, a = gen_model.step(
                    sess, encoder_inputs, decoder_inputs, target_weights,
                    bucket_id, forward_only=False)
                t_loss += t_step_loss / (
                    G_STEPS * gen_config.steps_per_checkpoint)
                print("t_step_loss: ", t_step_loss)
                print("t_adjusted_loss", t_adjusted_loss)

            # --- Periodic logging / TensorBoard summaries / checkpoints. ---
            if current_step % gen_config.steps_per_checkpoint == 0:
                step_time += (time.time() -
                              start_time) / gen_config.steps_per_checkpoint
                print(
                    "current_steps: %d, step time: %.4f, disc_loss: %.3f, gen_loss: %.3f, t_loss: %.3f, reward: %.3f "
                    % (current_step, step_time, disc_loss, gen_loss, t_loss,
                       batch_reward))

                # Discriminator loss summary.
                disc_loss_value = disc_loss_summary.value.add()
                disc_loss_value.tag = disc_config.name_loss
                disc_loss_value.simple_value = float(disc_loss)
                disc_writer.add_summary(disc_loss_summary,
                                        int(sess.run(disc_model.global_step)))

                # Generator loss / teacher loss / reward summaries, all logged
                # against the generator's global step.
                gen_global_steps = sess.run(gen_model.global_step)
                gen_loss_value = gen_loss_summary.value.add()
                gen_loss_value.tag = gen_config.name_loss
                gen_loss_value.simple_value = float(gen_loss)
                t_loss_value = gen_loss_summary.value.add()
                t_loss_value.tag = gen_config.teacher_loss
                t_loss_value.simple_value = float(t_loss)
                batch_reward_value = gen_loss_summary.value.add()
                batch_reward_value.tag = gen_config.reward_name
                batch_reward_value.simple_value = float(batch_reward)
                gen_writer.add_summary(gen_loss_summary,
                                       int(gen_global_steps))

                # Save both models every 4th checkpoint interval.
                if current_step % (gen_config.steps_per_checkpoint * 4) == 0:
                    print("current_steps: %d, save disc model" % current_step)
                    disc_ckpt_dir = os.path.abspath(
                        os.path.join(disc_config.train_dir, "checkpoints"))
                    if not os.path.exists(disc_ckpt_dir):
                        os.makedirs(disc_ckpt_dir)
                    disc_model_path = os.path.join(disc_ckpt_dir, "disc.model")
                    disc_model.saver.save(sess, disc_model_path,
                                          global_step=disc_model.global_step)

                    print("current_steps: %d, save gen model" % current_step)
                    gen_ckpt_dir = os.path.abspath(
                        os.path.join(gen_config.train_dir, "checkpoints"))
                    if not os.path.exists(gen_ckpt_dir):
                        os.makedirs(gen_ckpt_dir)
                    gen_model_path = os.path.join(gen_ckpt_dir, "gen.model")
                    gen_model.saver.save(sess, gen_model_path,
                                         global_step=gen_model.global_step)

                # Reset the per-checkpoint accumulators.
                step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
                sys.stdout.flush()