def main(unused_argv): if len(unused_argv ) != 1: # prints a message if you've entered flags incorrectly raise Exception("Problem with flags: %s" % unused_argv) tf.logging.set_verbosity( tf.logging.INFO) # choose what level of logging you want tf.logging.info('Starting running in %s mode...', (FLAGS.mode)) # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name) if not os.path.exists(FLAGS.log_root): if FLAGS.mode == "train": os.makedirs(FLAGS.log_root) else: raise Exception( "Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root)) vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary # Make a namedtuple hps, containing the values of the hyperparameters that the model needs hparam_list = [ 'vocab_size', 'dataset', 'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num', 'max_enc_num', 'max_dec_steps', 'max_enc_steps' ] hps_dict = {} for key, val in FLAGS.__flags.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict) hparam_list = [ 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'srl_max_dec_seq_len', 'srl_max_dec_sen_num', 'srl_max_enc_seq_len', 'srl_max_enc_sen_num' ] hps_dict = {} for key, val in FLAGS.__flags.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_srl_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict) hparam_list = [ 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'sc_max_dec_seq_len', 'sc_max_enc_seq_len' ] hps_dict = {} for key, val in FLAGS.__flags.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_sc_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict) # Create a batcher object that will create minibatches of data sc_batcher = Sc_GenBatcher(vocab, hps_sc_generator) tf.set_random_seed(111) # a seed value for randomness if hps_generator.mode == 'train': print("Start pre-training......") sc_model = Sc_Generator(hps_sc_generator, vocab) sess_sc, saver_sc, train_dir_sc = setup_training_sc_generator(sc_model) sc_generated = Generated_sc_sample(sc_model, vocab, sess_sc) print("Start pre-training generator......") run_pre_train_sc_generator(sc_model, sc_batcher, 40, sess_sc, saver_sc, train_dir_sc, sc_generated) if not os.path.exists("data/" + str(0) + "/"): os.mkdir("data/" + str(0) + "/") sc_generated.generator_max_example_test( sc_batcher.get_batches("pre-train"), "data/" + str(0) + "/train_skeleton.txt") sc_generated.generator_max_example_test( sc_batcher.get_batches("pre-valid"), "data/" + str(0) + "/valid_skeleton.txt") sc_generated.generator_max_example_test( sc_batcher.get_batches("pre-test"), "data/" + str(0) + "/test_skeleton.txt") merge("data/story/train_process.txt", "data/0/train_skeleton.txt", "data/0/train.txt") merge("data/story/validation_process.txt", "data/0/valid_skeleton.txt", "data/0/valid.txt") merge("data/story/test_process.txt", "data/0/test_skeleton.txt", "data/0/test.txt") ################################################################################################# batcher = GenBatcher(vocab, 
hps_generator) srl_batcher = Srl_GenBatcher(vocab, hps_srl_generator) print("Start pre-training......") model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(model) generated = Generated_sample(model, vocab, sess_ge) print("Start pre-training generator......") run_pre_train_generator(model, batcher, 30, sess_ge, saver_ge, train_dir_ge, generated) ################################################################################################## srl_generator_model = Srl_Generator(hps_srl_generator, vocab) sess_srl_ge, saver_srl_ge, train_dir_srl_ge = setup_training_srl_generator( srl_generator_model) util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") util.load_ckpt(saver_sc, sess_sc, ckpt_dir="train-sc-generator") srl_generated = Generated_srl_sample(srl_generator_model, vocab, sess_srl_ge) whole_generated = Generated_whole_sample(model, srl_generator_model, vocab, sess_ge, sess_srl_ge, batcher, srl_batcher) print("Start pre-training srl_generator......") run_pre_train_srl_generator(srl_generator_model, batcher, srl_batcher, 20, sess_srl_ge, saver_srl_ge, train_dir_srl_ge, srl_generated, whole_generated) loss_window = 0 t0 = time.time() print("begin reinforcement learning:") for epoch in range(10): loss_window = 0.0 batcher = GenBatcher(vocab, hps_generator) srl_batcher = Srl_GenBatcher(vocab, hps_srl_generator) batches = batcher.get_batches(mode='train') srl_batches = srl_batcher.get_batches(mode='train') sc_batches = sc_batcher.get_batches(mode='train') len_sc = len(sc_batches) for i in range(min(len(batches), len(srl_batches))): current_batch = batches[i] current_srl_batch = srl_batches[i] current_sc_batch = sc_batches[i % (len_sc - 1)] results = model.run_pre_train_step(sess_ge, current_batch) loss_list = results['without_average_loss'] example_skeleton_list = current_batch.original_review_outputs example_text_list = current_batch.original_target_sentences new_batch = sc_batcher.get_text_queue(example_skeleton_list, example_text_list, loss_list) results_sc = sc_model.run_rl_train_step(sess_sc, new_batch) loss = results_sc['loss'] loss_window += loss results_srl = srl_generator_model.run_pre_train_step( sess_srl_ge, current_srl_batch) loss_list_srl = results_srl['without_average_loss'] example_srl_text_list = current_srl_batch.orig_outputs example_skeleton_srl_list = current_srl_batch.orig_inputs new_batch = sc_batcher.get_text_queue( example_skeleton_srl_list, example_srl_text_list, loss_list_srl) results_sc = sc_model.run_rl_train_step(sess_sc, new_batch) loss = results_sc['loss'] loss_window += loss results_sc = sc_model.run_rl_train_step( sess_sc, current_sc_batch) loss = results_sc['loss'] loss_window += loss train_step = results['global_step'] if train_step % 100 == 0: t1 = time.time() tf.logging.info( 'seconds for %d training generator step: %.3f ', train_step, (t1 - t0) / 300) t0 = time.time() tf.logging.info('loss: %f', loss_window / 100) # print the loss to screen loss_window = 0.0 train_srl_step = results_srl['global_step'] if train_srl_step % 10000 == 0: saver_sc.save(sess_sc, train_dir_sc + "/model", global_step=results_sc['global_step']) saver_ge.save(sess_ge, train_dir_ge + "/model", global_step=train_step) saver_srl_ge.save(sess_srl_ge, train_dir_srl_ge + "/model", global_step=train_srl_step) srl_generated.generator_max_example( srl_batcher.get_batches("validation"), "to_seq_max_generated/valid/" + str(int(train_srl_step / 30000)) + "_positive", "to_seq_max_generated/valid/" + str(int(train_srl_step / 30000)) + 
"_negative") srl_generated.generator_max_example( srl_batcher.get_batches("test"), "to_seq_max_generated/test/" + str(int(train_srl_step / 30000)) + "_positive", "to_seq_max_generated/test/" + str(int(train_srl_step / 30000)) + "_negative") whole_generated.generator_max_example( batcher.get_batches("test-validation"), "max_generated_final/valid/" + str(int(train_srl_step / 30000)) + "_positive", "max_generated_final/valid/" + str(int(train_srl_step / 30000)) + "_negative") whole_generated.generator_max_example( batcher.get_batches("test-test"), "max_generated_final/test/" + str(int(train_srl_step / 30000)) + "_positive", "max_generated_final/test/" + str(int(train_srl_step / 30000)) + "_negative") sc_generated.generator_max_example_test( sc_batcher.get_batches("pre-train"), "data/" + str(0) + "/train_skeleton.txt") sc_generated.generator_max_example_test( sc_batcher.get_batches("pre-valid"), "data/" + str(0) + "/valid_skeleton.txt") sc_generated.generator_max_example_test( sc_batcher.get_batches("pre-test"), "data/" + str(0) + "/test_skeleton.txt") merge("data/story/train_process.txt", "data/0/train_skeleton.txt", "data/0/train.txt") merge("data/story/validation_process.txt", "data/0/valid_skeleton.txt", "data/0/valid.txt") merge("data/story/test_process.txt", "data/0/test_skeleton.txt", "data/0/test.txt") else: raise ValueError("The 'mode' flag must be one of train/eval/decode")
def main(unused_argv): # %% # choose what level of logging you want tf.logging.set_verbosity(tf.logging.INFO) tf.logging.info('Starting running in %s mode...', (FLAGS.mode)) # create the vocabulary vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) hparam_list = [ 'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num', 'max_dec_steps', 'max_enc_steps' ] hps_dict = {} for key, val in FLAGS.__flags.items(): if key in hparam_list: hps_dict[key] = val.value # add it to the dict hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict) hparam_list = [ 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_enc_sen_num', 'max_enc_seq_len' ] hps_dict = {} for key, val in FLAGS.__flags.items(): # for each flag if key in hparam_list: hps_dict[key] = val.value # add it to the dict hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict) # # fetch the data for the smallest batch size batcher = GenBatcher(vocab, hps_generator) # print(batcher.train_batch[0].original_review_inputs) # print(len(batcher.train_batch[0].original_review_inputs)) tf.set_random_seed(123) # %% if FLAGS.mode == 'train_generator': # print("Start pre-training ......") ge_model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(ge_model) generated = Generated_sample(ge_model, vocab, batcher, sess_ge) print("Start pre-training generator......") run_pre_train_generator(ge_model, batcher, 300, sess_ge, saver_ge, train_dir_ge) # util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") print("finish load train-generator") print("Generating negative examples......") generator_graph = tf.Graph() with generator_graph.as_default(): util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") print("finish load train-generator") generated.generator_train_negative_example() generated.generator_test_negative_example() print("finish write") elif FLAGS.mode == 'train_discriminator': # print("Start pre-training ......") model_dis = Discriminator(hps_discriminator, vocab) dis_batcher = DisBatcher(hps_discriminator, vocab, "discriminator_train/positive/*", "discriminator_train/negative/*", "discriminator_test/positive/*", "discriminator_test/negative/*") sess_dis, saver_dis, train_dir_dis = setup_training_discriminator( model_dis) print("Start pre-training discriminator......") if not os.path.exists("discriminator_result"): os.mkdir("discriminator_result") run_pre_train_discriminator(model_dis, dis_batcher, 1000, sess_dis, saver_dis, train_dir_dis) elif FLAGS.mode == "adversarial_train": generator_graph = tf.Graph() discriminator_graph = tf.Graph() print("Start adversarial-training......") # tf.reset_default_graph() with generator_graph.as_default(): model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(model) generated = Generated_sample(model, vocab, batcher, sess_ge) util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") print("finish load train-generator") with discriminator_graph.as_default(): model_dis = Discriminator(hps_discriminator, vocab) dis_batcher = DisBatcher(hps_discriminator, vocab, "discriminator_train/positive/*", "discriminator_train/negative/*", "discriminator_test/positive/*", "discriminator_test/negative/*") sess_dis, saver_dis, train_dir_dis = setup_training_discriminator( model_dis) util.load_ckpt(saver_dis, sess_dis, ckpt_dir="train-discriminator") print("finish load 
train-discriminator") print("Start adversarial training......") if not os.path.exists("train_sample_generated"): os.mkdir("train_sample_generated") if not os.path.exists("test_max_generated"): os.mkdir("test_max_generated") if not os.path.exists("test_sample_generated"): os.mkdir("test_sample_generated") whole_decay = False for epoch in range(100): print('Start training') batches = batcher.get_batches(mode='train') for step in range(int(len(batches) / 14)): run_train_generator(model, model_dis, sess_dis, batcher, dis_batcher, batches[step * 14:(step + 1) * 14], sess_ge, saver_ge, train_dir_ge) generated.generator_sample_example( "train_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive", "train_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 14) tf.logging.info("test performance: ") tf.logging.info("epoch: " + str(epoch) + " step: " + str(step)) # print("evaluate the diversity of DP-GAN (decode based on max probability)") # generated.generator_test_sample_example( # "test_sample_generated/" + # str(epoch) + "epoch_step" + str(step) + "_temp_positive", # "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 14) # # print("evaluate the diversity of DP-GAN (decode based on sampling)") # generated.generator_test_max_example( # "test_max_generated/" + # str(epoch) + "epoch_step" + str(step) + "_temp_positive", # "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 14) dis_batcher.train_queue = [] for i in range(epoch + 1): for j in range(step + 1): dis_batcher.train_queue += dis_batcher.fill_example_queue( "train_sample_generated/" + str(i) + "epoch_step" + str(j) + "_temp_positive/*") dis_batcher.train_queue += dis_batcher.fill_example_queue( "train_sample_generated/" + str(i) + "epoch_step" + str(j) + "_temp_negative/*") dis_batcher.train_batch = dis_batcher.create_batches( mode="train", shuffleis=True) whole_decay = run_train_discriminator( model_dis, 5, dis_batcher, dis_batcher.get_batches(mode="train"), sess_dis, saver_dis, train_dir_dis, whole_decay) elif FLAGS.mode == "test_language_model": ge_model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(ge_model) util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") print("finish load train-generator") # generator_graph = tf.Graph() # with generator_graph.as_default(): # util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") # print("finish load train-generator") # jieba.load_userdict('dir.txt') inputs = '' while inputs != "close": inputs = input("Enter your question: ") sentence = segmentor.segment(t2s.convert(inputs)) # sentence = jieba.cut(inputs) sentence = (" ".join(sentence)) sentence = s2t.convert(sentence) print(sentence) sentence = sentence.split() enc_input = [vocab.word2id(w) for w in sentence] enc_lens = np.array([len(enc_input)]) enc_input = np.array([enc_input]) out_sentence = ('[START]').split() dec_batch = [vocab.word2id(w) for w in out_sentence] #dec_batch = [2] + dec_batch # dec_batch.append(3) while len(dec_batch) < 40: dec_batch.append(1) dec_batch = np.array([dec_batch]) dec_batch = np.resize(dec_batch, (1, 1, 40)) dec_lens = np.array([len(dec_batch)]) if (FLAGS.beamsearch == 'beamsearch_train'): result = ge_model.run_test_language_model( sess_ge, enc_input, enc_lens, dec_batch, dec_lens) # print(result['generated']) # print(result['generated'].shape) output_ids = result['generated'][0] decoded_words = data.outputids2words(output_ids, vocab, None) 
print("decoded_words :", decoded_words) else: results = ge_model.run_test_beamsearch_example( sess_ge, enc_input, enc_lens, dec_batch, dec_lens) beamsearch_outputs = results['beamsearch_outputs'] for i in range(5): predict_list = np.ndarray.tolist(beamsearch_outputs[:, :, i]) predict_list = predict_list[0] predict_seq = [vocab.id2word(idx) for idx in predict_list] decoded_words = " ".join(predict_seq).split() # decoded_words = decoded_words try: if decoded_words[0] == '[STOPDOC]': decoded_words = decoded_words[1:] # index of the (first) [STOP] symbol fst_stop_idx = decoded_words.index(data.STOP_DECODING) decoded_words = decoded_words[:fst_stop_idx] except ValueError: decoded_words = decoded_words if decoded_words[-1] != '.' and decoded_words[ -1] != '!' and decoded_words[-1] != '?': decoded_words.append('.') decoded_words_all = [] decoded_output = ' '.join( decoded_words).strip() # single string decoded_words_all.append(decoded_output) decoded_words_all = ' '.join(decoded_words_all).strip() decoded_words_all = decoded_words_all.replace("[UNK] ", "") decoded_words_all = decoded_words_all.replace("[UNK]", "") decoded_words_all = decoded_words_all.replace(" ", "") decoded_words_all, _ = re.subn(r"(! ){2,}", "", decoded_words_all) decoded_words_all, _ = re.subn(r"(\. ){2,}", "", decoded_words_all) if decoded_words_all.startswith(','): decoded_words_all = decoded_words_all[1:] print("The response: {}".format(decoded_words_all))
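# The string clean-up at the end of the "test_language_model" branch above can be read as one
# post-processing step. A minimal sketch that mirrors those operations in the same order
# (the helper name is illustrative, not part of the original code):
import re

def clean_response(decoded_words):
    text = ' '.join(decoded_words).strip()
    text = text.replace("[UNK] ", "").replace("[UNK]", "")
    text = text.replace(" ", "")  # the original drops all spaces (Chinese output)
    text, _ = re.subn(r"(! ){2,}", "", text)
    text, _ = re.subn(r"(\. ){2,}", "", text)
    if text.startswith(','):
        text = text[1:]
    return text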
def main():
    ################################
    ## Module 1: data preparation
    data_ = data.Data(args.data_dir, args.vocab_size)

    # process the ICD tree
    parient_children, level2_parients, leafNodes, adj, node2id, hier_dicts = utils.build_tree(
        os.path.join(args.data_dir, 'note_labeled.csv'))
    graph = utils.generate_graph(parient_children, node2id)
    args.node2id = node2id
    args.adj = torch.Tensor(adj).long().to(args.device)
    args.leafNodes = leafNodes
    args.hier_dicts = hier_dicts

    # TODO: details of the batcher object
    g_batcher = GenBatcher(data_, args)

    #################################
    ## Module 2: build the G model and pre-train it
    # TODO: details of the Generator object
    gen_model = Generator(args, data_, graph, level2_parients)
    gen_model.to(args.device)

    # TODO: details of the generated object
    generated = Generated_example(gen_model, data_, g_batcher)

    # pre-train the G model
    pre_train_generator(gen_model, g_batcher, 10)

    # use G to generate some negative samples
    generated.generator_train_negative_samples()
    generated.generator_test_negative_samples()

    #####################################
    ## Module 3: build the D model and pre-train it
    d_model = Discriminator(args, data_)
    d_batcher = DisBatcher(data_, args)

    # pre-train the D model
    pre_train_discriminator(d_model, d_batcher, 25)

    ########################################
    ## Module 4: alternately train the G and D models
    for epoch in range(args.num_epochs):
        batches = g_batcher.get_batches(mode='train')
        for step in range(int(len(batches) / 1000)):
            # train the G model
            train_generator(gen_model, d_model, g_batcher, d_batcher,
                            batches[step * 1000:(step + 1) * 1000], generated)
            # generate negative samples for training D
            generated.generator_samples(
                "train_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                "train_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 1000)
            # generate test samples
            generated.generator_test_samples()
            # TODO: evaluate the performance of the G model

            # build the training batches for D (containing both negative and positive samples)
            d_batcher.train_batch = d_batcher.create_batches(mode='train', shuffleis=True)
            # train the D network
            train_discriminator(d_model, 5, d_batcher, d_batcher.get_batches(mode="train"))
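# In the adversarial loop above, the batch list is consumed in fixed slices of 1000 and any
# remainder is silently dropped by int(len(batches) / 1000). A small sketch of a chunking
# helper that makes that behaviour explicit (the name and signature are illustrative):
def chunks(items, size, drop_last=True):
    """Yield consecutive slices of `items` of length `size`."""
    end = len(items) - (len(items) % size) if drop_last else len(items)
    for start in range(0, end, size):
        yield items[start:start + size]

# usage sketch: for step_batches in chunks(batches, 1000): train_generator(..., step_batches, ...)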
def main(unused_argv): if len(unused_argv ) != 1: # prints a message if you've entered flags incorrectly raise Exception("Problem with flags: %s" % unused_argv) tf.logging.set_verbosity( tf.logging.INFO) # choose what level of logging you want tf.logging.info('Starting running in %s mode...', (FLAGS.mode)) # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name) if not os.path.exists(FLAGS.log_root): if FLAGS.mode == "train": os.makedirs(FLAGS.log_root) else: raise Exception( "Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root)) vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary # Make a namedtuple hps, containing the values of the hyperparameters that the model needs hparam_list = [ 'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps' ] hps_dict = {} for key, val in FLAGS.__flags.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict) hparam_list = [ 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps' ] hps_dict = {} for key, val in FLAGS.__flags.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict) tf.set_random_seed( 111 ) # a seed value for randomness # train-classification train-sentiment train-cnn-classificatin train-generator if FLAGS.mode == "train-classifier": #print("Start pre-training......") model_class = Classification(hps_discriminator, vocab) cla_batcher = ClaBatcher(hps_discriminator, vocab) sess_cls, saver_cls, train_dir_cls = setup_training_classification( model_class) print("Start pre-training classification......") run_pre_train_classification(model_class, cla_batcher, 1, sess_cls, saver_cls, train_dir_cls) #10 generated = Generate_training_sample(model_class, vocab, cla_batcher, sess_cls) print("Generating training examples......") generated.generate_training_example("train") generated.generate_test_example("test") elif FLAGS.mode == "train-sentimentor": model_class = Classification(hps_discriminator, vocab) cla_batcher = ClaBatcher(hps_discriminator, vocab) sess_cls, saver_cls, train_dir_cls = setup_training_classification( model_class) print("Start pre_train_sentimentor......") model_sentiment = Sentimentor(hps_generator, vocab) sentiment_batcher = SenBatcher(hps_generator, vocab) sess_sen, saver_sen, train_dir_sen = setup_training_sentimentor( model_sentiment) util.load_ckpt(saver_cls, sess_cls, ckpt_dir="train-classification") run_pre_train_sentimentor(model_sentiment, sentiment_batcher, 1, sess_sen, saver_sen, train_dir_sen) #1 elif FLAGS.mode == "test": config = { 'n_epochs': 5, 'kernel_sizes': [3, 4, 5], 'dropout_rate': 0.5, 'val_split': 0.4, 'edim': 300, 'n_words': None, # Leave as none 'std_dev': 0.05, 'sentence_len': 50, 'n_filters': 100, 'batch_size': 50 } config['n_words'] = 50000 cla_cnn_batcher = CNN_ClaBatcher(hps_discriminator, vocab) cnn_classifier = CNN(config) sess_cnn_cls, saver_cnn_cls, train_dir_cnn_cls = setup_training_cnnclassifier( cnn_classifier) #util.load_ckpt(saver_cnn_cls, sess_cnn_cls, ckpt_dir="train-cnnclassification") 
run_train_cnn_classifier(cnn_classifier, cla_cnn_batcher, 1, sess_cnn_cls, saver_cnn_cls, train_dir_cnn_cls) #1 files = os.listdir("test-generate-transfer/") for file_ in files: run_test_our_method(cla_cnn_batcher, cnn_classifier, sess_cnn_cls, "test-generate-transfer/" + file_ + "/*") #elif FLAGS.mode == "test": elif FLAGS.mode == "train-generator": model_class = Classification(hps_discriminator, vocab) cla_batcher = ClaBatcher(hps_discriminator, vocab) sess_cls, saver_cls, train_dir_cls = setup_training_classification( model_class) model_sentiment = Sentimentor(hps_generator, vocab) sentiment_batcher = SenBatcher(hps_generator, vocab) sess_sen, saver_sen, train_dir_sen = setup_training_sentimentor( model_sentiment) config = { 'n_epochs': 5, 'kernel_sizes': [3, 4, 5], 'dropout_rate': 0.5, 'val_split': 0.4, 'edim': 300, 'n_words': None, # Leave as none 'std_dev': 0.05, 'sentence_len': 50, 'n_filters': 100, 'batch_size': 50 } config['n_words'] = 50000 cla_cnn_batcher = CNN_ClaBatcher(hps_discriminator, vocab) cnn_classifier = CNN(config) sess_cnn_cls, saver_cnn_cls, train_dir_cnn_cls = setup_training_cnnclassifier( cnn_classifier) model = Generator(hps_generator, vocab) batcher = GenBatcher(vocab, hps_generator) sess_ge, saver_ge, train_dir_ge = setup_training_generator(model) util.load_ckpt(saver_cnn_cls, sess_cnn_cls, ckpt_dir="train-cnnclassification") util.load_ckpt(saver_sen, sess_sen, ckpt_dir="train-sentimentor") generated = Generated_sample(model, vocab, batcher, sess_ge) print("Start pre-training generator......") run_pre_train_generator(model, batcher, 1, sess_ge, saver_ge, train_dir_ge, generated, cla_cnn_batcher, cnn_classifier, sess_cnn_cls) # 4 generated.generate_test_negetive_example( "temp_negetive", batcher) # batcher, model_class, sess_cls, cla_batcher generated.generate_test_positive_example("temp_positive", batcher) #run_test_our_method(cla_cnn_batcher, cnn_classifier, sess_cnn_cls, # "temp_negetive" + "/*") loss_window = 0 t0 = time.time() print("begin reinforcement learning:") for epoch in range(30): batches = batcher.get_batches(mode='train') for i in range(len(batches)): current_batch = copy.deepcopy(batches[i]) sentiment_batch = batch_sentiment_batch( current_batch, sentiment_batcher) result = model_sentiment.max_generator(sess_sen, sentiment_batch) weight = result['generated'] current_batch.weight = weight sentiment_batch.weight = weight cla_batch = batch_classification_batch(current_batch, batcher, cla_batcher) result = model_class.run_ypred_auc(sess_cls, cla_batch) cc = SmoothingFunction() reward_sentiment = 1 - np.abs(0.5 - result['y_pred_auc']) reward_BLEU = [] for k in range(FLAGS.batch_size): reward_BLEU.append( sentence_bleu( [current_batch.original_reviews[k].split()], cla_batch.original_reviews[k].split(), smoothing_function=cc.method1)) reward_BLEU = np.array(reward_BLEU) reward_de = (2 / (1.0 / (1e-6 + reward_sentiment) + 1.0 / (1e-6 + reward_BLEU))) result = model.run_train_step(sess_ge, current_batch) train_step = result[ 'global_step'] # we need this to update our running average loss loss = result['loss'] loss_window += loss if train_step % 100 == 0: t1 = time.time() tf.logging.info( 'seconds for %d training generator step: %.3f ', train_step, (t1 - t0) / 100) t0 = time.time() tf.logging.info('loss: %f', loss_window / 100) # print the loss to screen loss_window = 0.0 if train_step % 10000 == 0: generated.generate_test_negetive_example( "test-generate-transfer/" + str(epoch) + "epoch_step" + str(train_step) + "_temp_positive", batcher) 
generated.generate_test_positive_example( "test-generate/" + str(epoch) + "epoch_step" + str(train_step) + "_temp_positive", batcher) #saver_ge.save(sess, train_dir + "/model", global_step=train_step) #run_test_our_method(cla_cnn_batcher, cnn_classifier, sess_cnn_cls, # "test-generate-transfer/" + str(epoch) + "epoch_step" + str( # train_step) + "_temp_positive" + "/*") cla_batch, bleu = output_to_classification_batch( result['generated'], current_batch, batcher, cla_batcher, cc) result = model_class.run_ypred_auc(sess_cls, cla_batch) reward_result_sentiment = result['y_pred_auc'] reward_result_bleu = np.array(bleu) reward_result = (2 / (1.0 / (1e-6 + reward_result_sentiment) + 1.0 / (1e-6 + reward_result_bleu))) current_batch.score = 1 - current_batch.score result = model.max_generator(sess_ge, current_batch) cla_batch, bleu = output_to_classification_batch( result['generated'], current_batch, batcher, cla_batcher, cc) result = model_class.run_ypred_auc(sess_cls, cla_batch) reward_result_transfer_sentiment = result['y_pred_auc'] reward_result_transfer_bleu = np.array(bleu) reward_result_transfer = ( 2 / (1.0 / (1e-6 + reward_result_transfer_sentiment) + 1.0 / (1e-6 + reward_result_transfer_bleu))) #tf.logging.info("reward_nonsentiment: "+str(reward_sentiment) +" output_original_sentiment: "+str(reward_result_sentiment)+" output_original_bleu: "+str(reward_result_bleu)) reward = reward_result_transfer #reward_de + reward_result_sentiment + #tf.logging.info("reward_de: "+str(reward_de)) model_sentiment.run_train_step(sess_sen, sentiment_batch, reward)
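# The rewards above (reward_de, reward_result, reward_result_transfer) all combine a sentiment
# score and a BLEU score as a harmonic mean with a small epsilon for numerical stability. A
# standalone sketch of that combination (the function name is illustrative):
import numpy as np

def harmonic_reward(sentiment_score, bleu_score, eps=1e-6):
    """2 / (1/a + 1/b): close to 1 only when both components are high."""
    return 2.0 / (1.0 / (eps + sentiment_score) + 1.0 / (eps + bleu_score))

# e.g. harmonic_reward(0.9, np.array([0.8, 0.2])) is roughly array([0.85, 0.33])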
def main(unused_argv): if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly raise Exception("Problem with flags: %s" % unused_argv) tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want tf.logging.info('Starting running in %s mode...', (FLAGS.mode)) # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name) if not os.path.exists(FLAGS.log_root): if FLAGS.mode=="train": os.makedirs(FLAGS.log_root) else: raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root)) vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary # Make a namedtuple hps, containing the values of the hyperparameters that the model needs hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num','max_dec_steps', 'max_enc_steps'] hps_dict = {} for key,val in FLAGS.__flags.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict) hparam_list = ['lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_enc_sen_num', 'max_enc_seq_len'] hps_dict = {} for key, val in FLAGS.__flags.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict) # Create a batcher object that will create minibatches of data batcher = GenBatcher(vocab, hps_generator) tf.set_random_seed(111) # a seed value for randomness if hps_generator.mode == 'train': print("Start pre-training......") model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(model) generated = Generated_sample(model, vocab, batcher, sess_ge) print("Start pre-training generator......") run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge, train_dir_ge,generated) # this is an infinite loop until interrupted print("Generating negative examples......") generated.generator_whole_negetive_example() generated.generator_test_negetive_example() model_dis = Discriminator(hps_discriminator, vocab) dis_batcher = DisBatcher(hps_discriminator, vocab, "train/generated_samples_positive/*", "train/generated_samples_negetive/*", "test/generated_samples_positive/*", "test/generated_samples_negetive/*") sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis) print("Start pre-training discriminator......") #run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test") run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis, saver_dis, train_dir_dis) util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") generated.generator_sample_example("sample_temp_positive", "sample_temp_negetive", 1000) generated.generator_test_sample_example("test_sample_temp_positive", "test_sample_temp_negetive", 200) generated.generator_test_max_example("test_max_temp_positive", "test_max_temp_negetive", 200) tf.logging.info("true data diversity: ") eva = Evaluate() eva.diversity_evaluate("test_sample_temp_positive" + "/*") print("Start adversarial training......") whole_decay = False for epoch in range(1): batches = batcher.get_batches(mode='train') for step in range(int(len(batches)/1000)): 
run_train_generator(model,model_dis,sess_dis,batcher,dis_batcher,batches[step*1000:(step+1)*1000],sess_ge, saver_ge, train_dir_ge,generated) #(model, discriminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, generated): generated.generator_sample_example("sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 1000) #generated.generator_max_example("max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 200) tf.logging.info("test performance: ") tf.logging.info("epoch: "+str(epoch)+" step: "+str(step)) generated.generator_test_sample_example( "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive", "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive", 200) generated.generator_test_max_example("test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive", "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive", 200) dis_batcher.train_queue = [] dis_batcher.train_queue = [] for i in range(epoch+1): for j in range(step+1): dis_batcher.train_queue += dis_batcher.fill_example_queue("sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_positive/*") dis_batcher.train_queue += dis_batcher.fill_example_queue("sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_negetive/*") dis_batcher.train_batch = dis_batcher.create_batches(mode="train", shuffleis=True) #dis_batcher.valid_batch = dis_batcher.train_batch whole_decay = run_train_discriminator(model_dis, 5, dis_batcher, dis_batcher.get_batches(mode="train"), sess_dis, saver_dis, train_dir_dis, whole_decay) '''elif hps_generator.mode == 'decode': decode_model_hps = hps_generator # This will be the hyperparameters for the decoder model model = Generator(decode_model_hps, vocab) generated = Generated_sample(model, vocab, batcher) bleu_score = generated.compute_BLEU() tf.logging.info('bleu: %f', bleu_score) # print the loss to screen''' else: raise ValueError("The 'mode' flag must be one of train/eval/decode")
def main(unused_argv): if len(unused_argv ) != 1: # prints a message if you've entered flags incorrectly raise Exception("Problem with flags: %s" % unused_argv) tf.logging.set_verbosity( tf.logging.INFO) # choose what level of logging you want tf.logging.info('Starting running in %s mode...', (FLAGS.mode)) # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name) if not os.path.exists(FLAGS.log_root): if "train" in FLAGS.mode: os.makedirs(FLAGS.log_root) else: raise Exception( "Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root)) vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary # print('FLAGS.flag_values_dict() ->', FLAGS.flag_values_dict()) flags_dict = FLAGS.flag_values_dict() # Make a namedtuple hps, containing the values of the hyperparameters that the model needs hparam_list = [ 'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num', 'max_dec_steps', 'max_enc_steps' ] hps_dict = {} for key, val in flags_dict.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict print('hps_dict ->', json.dumps(hps_dict, ensure_ascii=False)) hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict) hparam_list = [ 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_enc_sen_num', 'max_enc_seq_len' ] hps_dict = {} for key, val in flags_dict.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict) # # test # model_dis = Discriminator(hps_discriminator, vocab) # model_dis.build_graph() # sys.exit(0) # # test print('before load batcher...') # Create a batcher object that will create minibatches of data batcher = GenBatcher(vocab, hps_generator) print('after load batcher...') tf.set_random_seed(111) # a seed value for randomness if hps_generator.mode == 'adversarial_train': print("Start pre-training......") model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(model) generated = Generated_sample(model, vocab, batcher, sess_ge) model_dis = Discriminator(hps_discriminator, vocab) dis_batcher = DisBatcher(hps_discriminator, vocab, "discriminator_train/positive/*", "discriminator_train/negative/*", "discriminator_test/positive/*", "discriminator_test/negative/*") sess_dis, saver_dis, train_dir_dis = setup_training_discriminator( model_dis) util.load_ckpt(saver_dis, sess_dis, ckpt_dir="train-discriminator") util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") if not os.path.exists("MLE"): os.mkdir("MLE") print("evaluate the diversity of MLE (decode based on sampling)") generated.generator_test_sample_example("MLE/" + "MLE_sample_positive", "MLE/" + "MLE_sample_negative", 200) print( "evaluate the diversity of MLE (decode based on max probability)") generated.generator_test_max_example("MLE/" + "MLE_max_temp_positive", "MLE/" + "MLE_max_temp_negative", 200) print("Start adversarial training......") if not os.path.exists("train_sample_generated"): os.mkdir("train_sample_generated") if not os.path.exists("test_max_generated"): os.mkdir("test_max_generated") if not os.path.exists("test_sample_generated"): os.mkdir("test_sample_generated") whole_decay = False for epoch in 
range(10): batches = batcher.get_batches(mode='train') for step in range(int(len(batches) / 1000)): run_train_generator( model, model_dis, sess_dis, batcher, dis_batcher, batches[step * 1000:(step + 1) * 1000], sess_ge, saver_ge, train_dir_ge, generated ) # (model, discirminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, generated): generated.generator_sample_example( "train_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive", "train_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 1000) # generated.generator_max_example("max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 200) tf.logging.info("test performance: ") tf.logging.info("epoch: " + str(epoch) + " step: " + str(step)) print( "evaluate the diversity of DP-GAN (decode based on max probability)" ) generated.generator_test_sample_example( "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive", "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 200) print( "evaluate the diversity of DP-GAN (decode based on sampling)" ) generated.generator_test_max_example( "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive", "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 200) dis_batcher.train_queue = [] dis_batcher.train_queue = [] for i in range(epoch + 1): for j in range(step + 1): dis_batcher.train_queue += dis_batcher.fill_example_queue( "train_sample_generated/" + str(i) + "epoch_step" + str(j) + "_temp_positive/*") dis_batcher.train_queue += dis_batcher.fill_example_queue( "train_sample_generated/" + str(i) + "epoch_step" + str(j) + "_temp_negative/*") dis_batcher.train_batch = dis_batcher.create_batches( mode="train", shuffleis=True) # dis_batcher.valid_batch = dis_batcher.train_batch whole_decay = run_train_discriminator( model_dis, 5, dis_batcher, dis_batcher.get_batches(mode="train"), sess_dis, saver_dis, train_dir_dis, whole_decay) elif hps_generator.mode == 'train_generator': print("Start pre-training......") model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(model) generated = Generated_sample(model, vocab, batcher, sess_ge) print("Start pre-training generator......") # this is an infinite loop until run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge, train_dir_ge, generated) print("Generating negative examples......") generated.generator_train_negative_example() generated.generator_test_negative_example() elif hps_generator.mode == 'train_discriminator': print("Start pre-training......") model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(model) # util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") model_dis = Discriminator(hps_discriminator, vocab) dis_batcher = DisBatcher(hps_discriminator, vocab, "discriminator_train/positive/*", "discriminator_train/negative/*", "discriminator_test/positive/*", "discriminator_test/negative/*") sess_dis, saver_dis, train_dir_dis = setup_training_discriminator( model_dis) print("Start pre-training discriminator......") # run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test") if not os.path.exists("discriminator_result"): os.mkdir("discriminator_result") run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis, saver_dis, train_dir_dis)
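# The mains in this file build their HParams namedtuple in three slightly different ways
# (FLAGS.__flags with raw values, FLAGS.__flags with Flag objects and .value, and
# FLAGS.flag_values_dict()), because the flags container changed across TensorFlow 1.x
# releases. A small sketch of a version-tolerant helper (name and structure are illustrative):
from collections import namedtuple

def build_hparams(flags_obj, names):
    if hasattr(flags_obj, "flag_values_dict"):
        flags_dict = flags_obj.flag_values_dict()   # newer tf.flags / absl
    else:
        flags_dict = flags_obj.__flags              # older tf.flags
    hps = {key: getattr(val, "value", val)          # unwrap Flag objects if present
           for key, val in flags_dict.items() if key in names}
    return namedtuple("HParams", hps.keys())(**hps)

# e.g. hps_generator = build_hparams(FLAGS, ['mode', 'lr', 'batch_size', 'hidden_dim'])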
def main(unused_argv): if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly raise Exception("Problem with flags: %s" % unused_argv) tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want tf.logging.info('Starting running in %s mode...', (FLAGS.mode)) # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name) if not os.path.exists(FLAGS.log_root): if "train" in FLAGS.mode: os.makedirs(FLAGS.log_root) else: raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root)) vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary # Make a namedtuple hps, containing the values of the hyperparameters that the model needs hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num','max_dec_steps', 'max_enc_steps'] hps_dict = {} for key,val in FLAGS.__flags.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict) hparam_list = ['lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_enc_sen_num', 'max_enc_seq_len'] hps_dict = {} for key, val in FLAGS.__flags.items(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict) # Create a batcher object that will create minibatches of data batcher = GenBatcher(vocab, hps_generator) tf.set_random_seed(111) # a seed value for randomness if hps_generator.mode == 'adversarial_train': print("Start pre-training......") model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(model) generated = Generated_sample(model, vocab, batcher, sess_ge) model_dis = Discriminator(hps_discriminator, vocab) dis_batcher = DisBatcher(hps_discriminator, vocab, "discriminator_train/positive/*", "discriminator_train/negative/*", "discriminator_test/positive/*", "discriminator_test/negative/*") sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis) util.load_ckpt(saver_dis, sess_dis, ckpt_dir="train-discriminator") util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") if not os.path.exists("MLE"): os.mkdir("MLE") print("evaluate the diversity of MLE (decode based on sampling)") generated.generator_test_sample_example("MLE/"+"MLE_sample_positive", "MLE/"+"MLE_sample_negative", 200) print("evaluate the diversity of MLE (decode based on max probability)") generated.generator_test_max_example("MLE/"+"MLE_max_temp_positive", "MLE/"+"MLE_max_temp_negative", 200) print("Start adversarial training......") if not os.path.exists("train_sample_generated"): os.mkdir("train_sample_generated") if not os.path.exists("test_max_generated"): os.mkdir("test_max_generated") if not os.path.exists("test_sample_generated"): os.mkdir("test_sample_generated") whole_decay = False for epoch in range(10): batches = batcher.get_batches(mode='train') for step in range(int(len(batches)/1000)): run_train_generator(model,model_dis,sess_dis,batcher,dis_batcher,batches[step*1000:(step+1)*1000],sess_ge, saver_ge, train_dir_ge,generated) #(model, discirminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, 
generated): generated.generator_sample_example("train_sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "train_sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negative", 1000) #generated.generator_max_example("max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 200) tf.logging.info("test performance: ") tf.logging.info("epoch: "+str(epoch)+" step: "+str(step)) print("evaluate the diversity of DP-GAN (decode based on max probability)") generated.generator_test_sample_example( "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive", "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 200) print("evaluate the diversity of DP-GAN (decode based on sampling)") generated.generator_test_max_example("test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive", "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 200) dis_batcher.train_queue = [] dis_batcher.train_queue = [] for i in range(epoch+1): for j in range(step+1): dis_batcher.train_queue += dis_batcher.fill_example_queue("train_sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_positive/*") dis_batcher.train_queue += dis_batcher.fill_example_queue("train_sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_negative/*") dis_batcher.train_batch = dis_batcher.create_batches(mode="train", shuffleis=True) #dis_batcher.valid_batch = dis_batcher.train_batch whole_decay = run_train_discriminator(model_dis, 5, dis_batcher, dis_batcher.get_batches(mode="train"), sess_dis, saver_dis, train_dir_dis, whole_decay) elif hps_generator.mode == 'train_generator': print("Start pre-training......") model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(model) generated = Generated_sample(model, vocab, batcher, sess_ge) print("Start pre-training generator......") run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge, train_dir_ge,generated) # this is an infinite loop until print("Generating negative examples......") generated.generator_train_negative_example() generated.generator_test_negative_example() elif hps_generator.mode == 'train_discriminator': print("Start pre-training......") model = Generator(hps_generator, vocab) sess_ge, saver_ge, train_dir_ge = setup_training_generator(model) #util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator") model_dis = Discriminator(hps_discriminator, vocab) dis_batcher = DisBatcher(hps_discriminator, vocab, "discriminator_train/positive/*", "discriminator_train/negative/*", "discriminator_test/positive/*", "discriminator_test/negative/*") sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis) print("Start pre-training discriminator......") #run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test") if not os.path.exists("discriminator_result"): os.mkdir("discriminator_result") run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis, saver_dis, train_dir_dis)
def main(): ################################ ## Module 1: data preparation data_ = data.Data(args.data_dir, args.vocab_size) # process the ICD tree parient_children, level2_parients, leafNodes, adj, node2id, hier_dicts = utils.build_tree( os.path.join(args.data_dir, 'note_labeled_v2.csv')) graph = utils.generate_graph(parient_children, node2id) args.node2id = node2id args.id2node = {id: node for node, id in node2id.items()} args.adj = torch.Tensor(adj).long().to(args.device) # args.leafNodes=leafNodes args.hier_dicts = hier_dicts # args.level2_parients=level2_parients #print('836:',args.id2node.get(836),args.id2node.get(0)) # TODO: details of the batcher object g_batcher = GenBatcher(data_, args) ################################# ## Module 2: build the G model and pre-train it # TODO: details of the Generator object gen_model_eval = Generator(args, data_, graph, level2_parients) gen_model_target = Generator(args, data_, graph, level2_parients) gen_model_target.eval() print(gen_model_eval) # for name,param in gen_model_eval.named_parameters(): # print(name,param.size(),type(param)) buffer = ReplayBuffer(capacity=100000) gen_model_eval.to(args.device) gen_model_target.to(args.device) # TODO: details of the generated object # pre-train the G model #pre_train_generator(gen_model,g_batcher,10) ##################################### ## Module 3: build the D model and pre-train it d_model = Discriminator(args) d_model.to(args.device) # pre-train the D model #pre_train_discriminator(d_model,d_batcher,25) ######################################## ## Module 4: alternately train the G and D models # write the evaluation results to a file f = open('valid_result.csv', 'w') writer = csv.writer(f) writer.writerow([ 'avg_micro_p', 'avg_macro_p', 'avg_micro_r', 'avg_macro_r', 'avg_micro_f1', 'avg_macro_f1', 'avg_micro_auc_roc', 'avg_macro_auc_roc' ]) epoch_f = [] for epoch in range(args.num_epochs): batches = g_batcher.get_batches(mode='train') print('number of batches:', len(batches)) for step in range(len(batches)): #print('step:',step) current_batch = batches[step] ehrs = [example.ehr for example in current_batch] ehrs = torch.Tensor(ehrs).long().to(args.device) hier_labels = [example.hier_labels for example in current_batch] true_labels = [] # pad hier_labels for i in range(len(hier_labels)): # i indexes the examples for j in range(len(hier_labels[i])): # j indexes each path of the example if len(hier_labels[i][j]) < 4: hier_labels[i][j] = hier_labels[i][j] + [0] * ( 4 - len(hier_labels[i][j])) # if len(hier_labels[i]) < args.k: # for time in range(args.k - len(hier_labels[i])): # hier_labels[i].append([0] * args.hops) for sample in hier_labels: #print('sample:',sample) true_labels.append([row[1] for row in sample]) predHierLabels, batchStates_n, batchHiddens_n = generator.generated_negative_samples( gen_model_eval, d_model, ehrs, hier_labels, buffer) #true_labels = [example.labels for example in current_batch] _, _, avgJaccard = full_eval.process_labels( predHierLabels, true_labels, args) # G generates positive samples for training D batchStates_p, batchHiddens_p = generator.generated_positive_samples( gen_model_eval, ehrs, hier_labels, buffer) # train the D network #d_loss=train_discriminator(d_model,batchStates_n,batchHiddens_n,batchStates_p,batchHiddens_p,mode=args.mode) # train the G model #for g_epoch in range(10): g_loss = train_generator(gen_model_eval, gen_model_target, d_model, batchStates_n, batchHiddens_n, buffer, mode=args.mode) print('batch_number:{}, avgJaccard:{:.4f}, g_loss:{:.4f}'.format( step, avgJaccard, g_loss)) # after each epoch, evaluate the G model and the D model on the validation set avg_micro_f1 = evaluate(g_batcher, gen_model_eval, d_model, buffer, writer, flag='valid') epoch_f.append(avg_micro_f1) # plot results window = int(args.num_epochs / 20) print('window:', window) fig, ((ax1), (ax2)) 
= plt.subplots(2, 1, sharey=True, figsize=[9, 9]) rolling_mean = pd.Series(epoch_f).rolling(window).mean() std = pd.Series(epoch_f).rolling(window).std() ax1.plot(rolling_mean) ax1.fill_between(range(len(epoch_f)), rolling_mean - std, rolling_mean + std, color='orange', alpha=0.2) ax1.set_title( 'Episode Length Moving Average ({}-episode window)'.format(window)) ax1.set_xlabel('Epoch Number') ax1.set_ylabel('F1') ax2.plot(epoch_f) ax2.set_title('Performance on valid set') ax2.set_xlabel('Epoch Number') ax2.set_ylabel('F1') fig.tight_layout(pad=2) plt.show() fig.savefig('results.png') f.close()
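# The nested loops in the training step above pad every hierarchical label path to a fixed
# depth of 4 with zeros. A compact standalone sketch of the same idea (the helper name and
# defaults are illustrative, not part of the original code):
def pad_paths(hier_labels, depth=4, pad_id=0):
    """Pad each label path in place so that every path has exactly `depth` entries."""
    for paths in hier_labels:          # one list of paths per example
        for j, path in enumerate(paths):
            if len(path) < depth:
                paths[j] = path + [pad_id] * (depth - len(path))
    return hier_labels

# e.g. pad_paths([[[3, 7], [3, 7, 9]]]) -> [[[3, 7, 0, 0], [3, 7, 9, 0]]]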
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting running in %s mode...', FLAGS.mode)

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if FLAGS.mode == "train":
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % FLAGS.log_root)

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_steps', 'max_enc_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    tf.set_random_seed(111)  # a seed value for randomness

    if hps_generator.mode == 'train':
        print("Start pre-training......")
        model_class = Classification(hps_discriminator, vocab)
        cla_batcher = ClaBatcher(hps_discriminator, vocab)
        sess_cls, saver_cls, train_dir_cls = setup_training_classification(model_class)
        print("Start pre-training classification......")
        # run_pre_train_classification(model_class, cla_batcher, 10, sess_cls, saver_cls, train_dir_cls)
        # generated = Generate_training_sample(model_class, vocab, cla_batcher, sess_cls)
        # print("Generating training examples......")
        # generated.generate_training_example("train")
        # generated.generator_validation_example("valid")

        model_sentiment = Sentimentor(hps_generator, vocab)
        sentiment_batcher = SenBatcher(hps_generator, vocab)
        sess_sen, saver_sen, train_dir_sen = setup_training_sentimentor(model_sentiment)
        # run_pre_train_sentimentor(model_sentiment, sentiment_batcher, 1, sess_sen, saver_sen, train_dir_sen)
        sentiment_generated = Generate_non_sentiment_weight(model_sentiment, vocab,
                                                            sentiment_batcher, sess_sen)
        # sentiment_generated.generate_training_example("train_sentiment")
        # sentiment_generated.generator_validation_example("valid_sentiment")

        model = Generator(hps_generator, vocab)
        # Create a batcher object that will create minibatches of data
        batcher = GenBatcher(vocab, hps_generator)
        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)

        util.load_ckpt(saver_sen, sess_sen, ckpt_dir="train-sentimentor")
        util.load_ckpt(saver_cls, sess_cls, ckpt_dir="train-classification")

        generated = Generated_sample(model, vocab, batcher, sess_ge)
        # print("Start pre-training generator......")
        run_pre_train_generator(model, batcher, 4, sess_ge, saver_ge, train_dir_ge,
                                generated, model_class, sess_cls,
                                cla_batcher)  # this is an infinite loop until interrupted
        # generated.generator_validation_negetive_example(
        #     "temp_negetive", batcher, model_class, sess_cls, cla_batcher)
        # generated.generator_validation_positive_example(
        #     "temp_positive", batcher, model_class, sess_cls, cla_batcher)

        loss_window = 0
        t0 = time.time()
        print("begin dual learning:")
        for epoch in range(30):
            batches = batcher.get_batches(mode='train')
            for i in range(len(batches)):
                current_batch = copy.deepcopy(batches[i])
                sentiment_batch = batch_sentiment_batch(current_batch, sentiment_batcher)
                result = model_sentiment.max_generator(sess_sen, sentiment_batch)
                weight = result['generated']
                current_batch.weight = weight
                sentiment_batch.weight = weight

                cla_batch = batch_classification_batch(current_batch, batcher, cla_batcher)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)

                cc = SmoothingFunction()
                reward_sentiment = 1 - np.abs(0.5 - result['y_pred_auc'])
                reward_BLEU = []
                for k in range(FLAGS.batch_size):
                    reward_BLEU.append(
                        sentence_bleu([current_batch.original_reviews[k].split()],
                                      cla_batch.original_reviews[k].split(),
                                      smoothing_function=cc.method1))
                reward_BLEU = np.array(reward_BLEU)
                # harmonic mean of the sentiment reward and the BLEU reward
                reward_de = 2 / (1.0 / (1e-6 + reward_sentiment) + 1.0 / (1e-6 + reward_BLEU))

                result = model.run_train_step(sess_ge, current_batch)
                train_step = result['global_step']  # we need this to update our running average loss
                loss = result['loss']
                loss_window += loss
                if train_step % 100 == 0:
                    t1 = time.time()
                    tf.logging.info('seconds for %d training generator steps: %.3f',
                                    train_step, (t1 - t0) / 100)
                    t0 = time.time()
                    tf.logging.info('loss: %f', loss_window / 100)  # print the loss to screen
                    loss_window = 0.0
                if train_step % 10000 == 0:
                    # bleu_score = generated.compute_BLEU(str(train_step))
                    # tf.logging.info('bleu: %f', bleu_score)  # print the BLEU score to screen
                    generated.generator_validation_negetive_example(
                        "valid-generated-transfer/" + str(epoch) + "epoch_step" + str(train_step) + "_temp_positive",
                        batcher, model_class, sess_cls, cla_batcher)
                    generated.generator_validation_positive_example(
                        "valid-generated/" + str(epoch) + "epoch_step" + str(train_step) + "_temp_positive",
                        batcher, model_class, sess_cls, cla_batcher)
                    # saver_ge.save(sess, train_dir + "/model", global_step=train_step)

                cla_batch, bleu = output_to_classification_batch(result['generated'], current_batch,
                                                                 batcher, cla_batcher, cc)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)
                reward_result_sentiment = result['y_pred_auc']
                reward_result_bleu = np.array(bleu)
                reward_result = 2 / (1.0 / (1e-6 + reward_result_sentiment) +
                                     1.0 / (1e-6 + reward_result_bleu))

                current_batch.score = 1 - current_batch.score
                result = model.max_generator(sess_ge, current_batch)
                cla_batch, bleu = output_to_classification_batch(result['generated'], current_batch,
                                                                 batcher, cla_batcher, cc)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)
                reward_result_transfer_sentiment = result['y_pred_auc']
                reward_result_transfer_bleu = np.array(bleu)
                reward_result_transfer = 2 / (1.0 / (1e-6 + reward_result_transfer_sentiment) +
                                              1.0 / (1e-6 + reward_result_transfer_bleu))

                # tf.logging.info("reward_nonsentiment: " + str(reward_sentiment) +
                #                 " output_original_sentiment: " + str(reward_result_sentiment) +
                #                 " output_original_bleu: " + str(reward_result_bleu))
                reward = reward_result_transfer  # reward_de + reward_result_sentiment +
                # tf.logging.info("reward_de: " + str(reward_de))
                model_sentiment.run_train_step(sess_sen, sentiment_batch, reward)

    elif hps_generator.mode == 'decode':
        decode_model_hps = hps_generator  # This will be the hyperparameters for the decoder model
        # model = Generator(decode_model_hps, vocab)
        # generated = Generated_sample(model, vocab, batcher)
        # bleu_score = generated.compute_BLEU()
        # tf.logging.info('bleu: %f', bleu_score)  # print the BLEU score to screen
    else:
        raise ValueError("The 'mode' flag must be one of train/eval/decode")
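The dual-learning loop above folds a sentiment-confidence reward and a BLEU-based content-preservation reward into a single score via a harmonic mean, so the combined reward stays low unless both components are reasonably high. The following is a minimal, self-contained sketch of just that combination; the function name `harmonic_mean_reward` and the toy inputs are illustrative and not part of the original code.

import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def harmonic_mean_reward(sentiment_conf, bleu, eps=1e-6):
    """Combine two per-example rewards with a harmonic mean, as in the loop above.

    Both inputs are arrays of values in [0, 1]; eps avoids division by zero.
    """
    sentiment_conf = np.asarray(sentiment_conf, dtype=np.float64)
    bleu = np.asarray(bleu, dtype=np.float64)
    return 2.0 / (1.0 / (eps + sentiment_conf) + 1.0 / (eps + bleu))

# Toy usage with illustrative values only:
cc = SmoothingFunction()
bleu = sentence_bleu([["the", "movie", "was", "great"]],
                     ["the", "film", "was", "great"],
                     smoothing_function=cc.method1)
print(harmonic_mean_reward([0.8], [bleu]))  # small if either component is small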
def main(argv):
    tf.set_random_seed(111)  # a seed value for randomness
    # Create a batcher object that will create minibatches of data
    # TODO change to pass number

    # --------------- building graph ---------------
    hparam_gen = [
        'mode', 'model_dir', 'adagrad_init_acc', 'steps_per_checkpoint',
        'batch_size', 'beam_size', 'cov_loss_wt', 'coverage', 'emb_dim',
        'rand_unif_init_mag', 'gen_vocab_file', 'gen_vocab_size', 'hidden_dim',
        'gen_lr', 'gen_max_gradient', 'max_dec_steps', 'max_enc_steps',
        'min_dec_steps', 'trunc_norm_init_std', 'single_pass', 'log_root',
        'data_path',
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_gen:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_gen = namedtuple("HParams4Gen", hps_dict.keys())(**hps_dict)

    print("Building vocabulary for generator ...")
    gen_vocab = Vocab(join_path(hps_gen.data_path, hps_gen.gen_vocab_file),
                      hps_gen.gen_vocab_size)

    hparam_dis = [
        'mode', 'vocab_type', 'model_dir', 'dis_vocab_size', 'steps_per_checkpoint',
        'learning_rate_decay_factor', 'dis_vocab_file', 'num_class', 'layer_size',
        'conv_layers', 'max_steps', 'kernel_size', 'early_stop', 'pool_size',
        'pool_layers', 'dis_max_gradient', 'batch_size', 'dis_lr', 'lr_decay_factor',
        'cell_type', 'max_enc_steps', 'max_dec_steps', 'single_pass', 'data_path',
        'num_models',
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_dis:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_dis = namedtuple("HParams4Dis", hps_dict.keys())(**hps_dict)

    if hps_gen.gen_vocab_file == hps_dis.dis_vocab_file:
        hps_dis = hps_dis._replace(vocab_type="word")
        hps_dis = hps_dis._replace(layer_size=hps_gen.emb_dim)
        hps_dis = hps_dis._replace(dis_vocab_size=hps_gen.gen_vocab_size)
    else:
        hps_dis = hps_dis._replace(max_enc_steps=hps_dis.max_enc_steps * 2)
        hps_dis = hps_dis._replace(max_dec_steps=hps_dis.max_dec_steps * 2)
    if FLAGS.mode == "train_gan":
        hps_gen = hps_gen._replace(batch_size=hps_gen.batch_size * hps_dis.num_models)

    if FLAGS.mode != "pretrain_dis":
        with tf.variable_scope("generator"):
            generator = PointerGenerator(hps_gen, gen_vocab)
            print("Building generator graph ...")
            gen_decoder_scope = generator.build_graph()

    if FLAGS.mode != "pretrain_gen":
        print("Building vocabulary for discriminator ...")
        dis_vocab = Vocab(join_path(hps_dis.data_path, hps_dis.dis_vocab_file),
                          hps_dis.dis_vocab_size)
    if FLAGS.mode in ['train_gan', 'pretrain_dis']:
        with tf.variable_scope("discriminator"), tf.device("/gpu:0"):
            discriminator = Seq2ClassModel(hps_dis)
            print("Building discriminator graph ...")
            discriminator.build_graph()

    hparam_gan = [
        'mode', 'model_dir', 'gan_iter', 'gan_gen_iter', 'gan_dis_iter',
        'gan_lr', 'rollout_num', 'sample_num',
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_gan:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_gan = namedtuple("HParams4GAN", hps_dict.keys())(**hps_dict)
    hps_gan = hps_gan._replace(mode="train_gan")

    if FLAGS.mode == 'train_gan':
        with tf.device("/gpu:0"):
            print("Creating rollout...")
            rollout = Rollout(generator, 0.8, gen_decoder_scope)

    # --------------- initializing variables ---------------
    all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) + \
        tf.get_collection_ref(tf.GraphKeys.WEIGHTS) + \
        tf.get_collection_ref(tf.GraphKeys.BIASES)
    sess = tf.Session(config=utils.get_config())
    sess.run(tf.variables_initializer(all_variables))

    if FLAGS.mode == "pretrain_gen":
        val_dir = ensure_exists(join_path(FLAGS.model_dir, 'generator', FLAGS.val_dir))
        model_dir = ensure_exists(join_path(FLAGS.model_dir, 'generator'))
        print("Restoring the generator model from the latest checkpoint...")
        gen_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables
                      if "generator" in v.name and "GAN" not in v.name])
        gen_dir = ensure_exists(join_path(FLAGS.model_dir, "generator"))
        # gen_dir = ensure_exists(FLAGS.model_dir)
        # temp_saver = tf.train.Saver(
        #     var_list=[v for v in all_variables if "generator" in v.name and "Adagrad" not in v.name])
        ckpt_path = utils.load_ckpt(gen_saver, sess, gen_dir)
        print('going to restore embeddings from checkpoint')
        if not ckpt_path:
            emb_path = join_path(FLAGS.model_dir, "generator", "init_embed")
            if emb_path:
                generator.saver.restore(
                    sess,
                    tf.train.get_checkpoint_state(emb_path).model_checkpoint_path)
                print(colored("successfully restored embeddings from %s" % emb_path, 'green'))
            else:
                print(colored("failed to restore embeddings from %s" % emb_path, 'red'))

    elif FLAGS.mode in ["decode", "train_gan"]:
        print("Restoring the generator model from the best checkpoint...")
        # val_dir is needed below in this branch as well, so define it here
        val_dir = ensure_exists(join_path(FLAGS.model_dir, 'generator', FLAGS.val_dir))
        dec_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        gan_dir = ensure_exists(join_path(FLAGS.model_dir, 'generator', FLAGS.gan_dir))
        gan_val_dir = ensure_exists(join_path(FLAGS.model_dir, 'generator', FLAGS.gan_dir, FLAGS.val_dir))
        gan_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        gan_val_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        utils.load_ckpt(dec_saver, sess, val_dir, (FLAGS.mode in ["train_gan", "decode"]))

    if FLAGS.mode in ["pretrain_dis", "train_gan"]:
        dis_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "discriminator" in v.name])
        dis_dir = ensure_exists(join_path(FLAGS.model_dir, 'discriminator'))
        ckpt = utils.load_ckpt(dis_saver, sess, dis_dir)
        if not ckpt:
            if hps_dis.vocab_type == "word":
                discriminator.init_emb(sess, join_path(FLAGS.model_dir, "generator", "init_embed"))
            else:
                discriminator.init_emb(sess, join_path(FLAGS.model_dir, "discriminator", "init_embed"))

    # --------------- train models ---------------
    if FLAGS.mode != "pretrain_dis":
        gen_batcher_train = GenBatcher("train", gen_vocab, hps_gen,
                                       single_pass=hps_gen.single_pass)
        decoder = Decoder(sess, generator, gen_vocab)
        gen_batcher_val = GenBatcher("val", gen_vocab, hps_gen, single_pass=True)
        val_saver = tf.train.Saver(
            max_to_keep=10,
            var_list=[v for v in all_variables
                      if "generator" in v.name and "GAN" not in v.name])

    if FLAGS.mode != "pretrain_gen":
        dis_val_batch_size = hps_dis.batch_size * hps_dis.num_models \
            if hps_dis.mode == "train_gan" else hps_dis.batch_size * hps_dis.num_models * 2
        dis_batcher_val = DisBatcher(
            hps_dis.data_path, "eval", gen_vocab, dis_vocab, dis_val_batch_size,
            single_pass=True,
            max_art_steps=hps_dis.max_enc_steps,
            max_abs_steps=hps_dis.max_dec_steps,
        )

    if FLAGS.mode == "pretrain_gen":
        print('Going to pretrain the generator')
        try:
            pretrain_generator(generator, gen_batcher_train, sess, gen_batcher_val,
                               gen_saver, model_dir, val_saver, val_dir)
        except KeyboardInterrupt:
            tf.logging.info("Caught keyboard interrupt on worker....")

    elif FLAGS.mode == "pretrain_dis":
        print('Going to pretrain the discriminator')
        dis_batcher = DisBatcher(
            hps_dis.data_path, "decode", gen_vocab, dis_vocab,
            hps_dis.batch_size * hps_dis.num_models,
            single_pass=hps_dis.single_pass,
            max_art_steps=hps_dis.max_enc_steps,
            max_abs_steps=hps_dis.max_dec_steps,
        )
        try:
            pretrain_discriminator(sess, discriminator, dis_batcher_val, dis_vocab,
                                   dis_batcher, dis_saver)
        except KeyboardInterrupt:
            tf.logging.info("Caught keyboard interrupt on worker....")

    elif FLAGS.mode == "train_gan":
        gen_best_loss = get_best_loss_from_chpt(val_dir)
        gen_global_step = 0
        print('Going to tune the two using GAN')
        for i_gan in range(hps_gan.gan_iter):
            # Train the generator for one step
            g_losses = []
            current_speed = []
            for it in range(hps_gan.gan_gen_iter):
                start_time = time.time()
                batch = gen_batcher_train.next_batch()

                # generate samples
                enc_states, dec_in_state, n_samples, n_targets_padding_mask = decoder.mc_generate(
                    batch, include_start_token=True, s_num=hps_gan.sample_num)
                # get rewards for the samples
                n_rewards = rollout.get_reward(sess, gen_vocab, dis_vocab, batch,
                                               enc_states, dec_in_state, n_samples,
                                               hps_gan.rollout_num, discriminator)

                # fine-tune the generator
                n_sample_targets = [samples[:, 1:] for samples in n_samples]
                n_targets_padding_mask = [padding_mask[:, 1:]
                                          for padding_mask in n_targets_padding_mask]
                n_samples = [samples[:, :-1] for samples in n_samples]
                # sample_target_padding_mask = pad_sample(sample_target, gen_vocab, hps_gen)
                # replace out-of-vocabulary ids in the samples with the UNK id
                n_samples = [
                    np.where(np.less(samples, hps_gen.gen_vocab_size), samples,
                             np.array([[gen_vocab.word2id(data.UNKNOWN_TOKEN)] * hps_gen.max_dec_steps]
                                      * hps_gen.batch_size))
                    for samples in n_samples
                ]
                results = generator.run_gan_batch(sess, batch, n_samples, n_sample_targets,
                                                  n_targets_padding_mask, n_rewards)

                gen_global_step = results["global_step"]

                # for visualization
                g_loss = results["loss"]
                if not math.isnan(g_loss):
                    g_losses.append(g_loss)
                else:
                    print(colored('a nan in gan loss', 'red'))
                current_speed.append(time.time() - start_time)

            # Test
            # if FLAGS.gan_gen_iter and (i_gan % 100 == 0 or i_gan == hps_gan.gan_iter - 1):
            if i_gan % 100 == 0 or i_gan == hps_gan.gan_iter - 1:
                print('Going to test the generator.')
                current_speed = sum(current_speed) / (len(current_speed) * hps_gen.batch_size)
                average_g_loss = sum(g_losses) / len(g_losses)
                # one more process should be opened for the evaluation
                eval_loss, gen_best_loss = save_ckpt(sess, generator, gen_best_loss, gan_dir,
                                                     gan_saver, gen_batcher_val, gan_val_dir,
                                                     gan_val_saver, gen_global_step)
                if eval_loss:
                    print("\nDashboard for " + colored("GAN Generator", 'green') + " updated %s, "
                          "finished steps:\t%s\n"
                          "\tBatch size:\t%s\n"
                          "\tVocabulary size:\t%s\n"
                          "\tCurrent speed:\t%.4f seconds/article\n"
                          "\tAverage training loss:\t%.4f; "
                          "eval loss:\t%.4f" % (
                              datetime.datetime.now().strftime("on %m-%d at %H:%M"),
                              gen_global_step,
                              FLAGS.batch_size,
                              hps_gen.gen_vocab_size,
                              current_speed,
                              average_g_loss.item(),
                              eval_loss.item(),
                          ))

            # Train the discriminator
            print('Going to train the discriminator.')
            dis_best_loss = 1000
            dis_losses = []
            dis_accuracies = []
            for d_gan in range(hps_gan.gan_dis_iter):
                batch = gen_batcher_train.next_batch()
                enc_states, dec_in_state, k_samples_words, _ = decoder.mc_generate(
                    batch, s_num=hps_gan.sample_num)
                # should first translate to words to avoid UNKs
                articles_oovs = batch.art_oovs
                for samples_words in k_samples_words:
                    dec_batch_words = batch.target_batch
                    conditions_words = batch.enc_batch_extend_vocab
                    if hps_dis.vocab_type == "char":
                        samples = gen_vocab2dis_vocab(samples_words, gen_vocab, articles_oovs,
                                                      dis_vocab, hps_dis.max_dec_steps, STOP_DECODING)
                        dec_batch = gen_vocab2dis_vocab(dec_batch_words, gen_vocab, articles_oovs,
                                                        dis_vocab, hps_dis.max_dec_steps, STOP_DECODING)
                        conditions = gen_vocab2dis_vocab(conditions_words, gen_vocab, articles_oovs,
                                                         dis_vocab, hps_dis.max_enc_steps, PAD_TOKEN)
                    else:
                        samples = samples_words
                        dec_batch = dec_batch_words
                        conditions = conditions_words
                        # the unknown in target

                    inputs = np.concatenate([samples, dec_batch], 0)
                    conditions = np.concatenate([conditions, conditions], 0)
                    targets = [[1, 0] for _ in samples] + [[0, 1] for _ in dec_batch]
                    targets = np.array(targets)
                    # randomize the samples
                    assert len(inputs) == len(conditions) == len(targets), \
                        "lengths of the inputs, conditions and targets should be the same."
                    indices = np.random.permutation(len(inputs))
                    inputs = np.split(inputs[indices], 2)
                    conditions = np.split(conditions[indices], 2)
                    targets = np.split(targets[indices], 2)
                    assert len(inputs) % 2 == 0, "the length should be even"
                    results = discriminator.run_one_batch(sess, inputs[0], conditions[0], targets[0])
                    dis_accuracies.append(results["accuracy"].item())
                    dis_losses.append(results["loss"].item())
                    results = discriminator.run_one_batch(sess, inputs[1], conditions[1], targets[1])
                    dis_accuracies.append(results["accuracy"].item())

                    ave_dis_acc = sum(dis_accuracies) / len(dis_accuracies)
                    if d_gan == hps_gan.gan_dis_iter - 1:
                        if (sum(dis_losses) / len(dis_losses)) < dis_best_loss:
                            dis_best_loss = sum(dis_losses) / len(dis_losses)
                            checkpoint_path = ensure_exists(
                                join_path(hps_dis.model_dir, "discriminator")) + "/model.ckpt"
                            dis_saver.save(sess, checkpoint_path,
                                           global_step=results["global_step"])
                        print_dashboard("GAN Discriminator", results["global_step"].item(),
                                        hps_dis.batch_size, hps_dis.dis_vocab_size,
                                        results["loss"].item(), 0.00, 0.00, 0.00)
                        print("Average training accuracy: \t%.4f" % ave_dis_acc)

                    if ave_dis_acc > 0.9:
                        break

    # --------------- decoding samples ---------------
    elif FLAGS.mode == "decode":
        print('Going to decode from the generator.')
        decoder.bs_decode(gen_batcher_train)
        print("Finished decoding..")
        # decode for generating corpus for discriminator

    sess.close()
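For reference, the discriminator update in `train_gan` mode stacks generated samples (labeled fake) on top of ground-truth targets (labeled real), duplicates the conditioning articles, shuffles everything, and splits the result into two half-batches. The sketch below isolates just that batch assembly with dummy arrays; the function name `build_dis_half_batches` and the toy shapes are illustrative assumptions, not part of the original code.

import numpy as np

def build_dis_half_batches(samples, dec_batch, conditions):
    """Pair fake and real sequences with one-hot labels, shuffle, and split in two.

    samples, dec_batch: [batch, dec_len] int arrays (generated vs. ground truth)
    conditions:         [batch, enc_len] int array (source articles)
    """
    inputs = np.concatenate([samples, dec_batch], 0)
    conds = np.concatenate([conditions, conditions], 0)
    # [1, 0] marks a generated sequence, [0, 1] a real one
    targets = np.array([[1, 0]] * len(samples) + [[0, 1]] * len(dec_batch))
    indices = np.random.permutation(len(inputs))
    return (np.split(inputs[indices], 2),
            np.split(conds[indices], 2),
            np.split(targets[indices], 2))

# Toy usage with random token ids:
fake = np.random.randint(0, 50, size=(4, 10))
real = np.random.randint(0, 50, size=(4, 10))
arts = np.random.randint(0, 50, size=(4, 20))
half_inputs, half_conds, half_targets = build_dis_half_batches(fake, real, arts)
assert half_inputs[0].shape == (4, 10) and half_targets[0].shape == (4, 2)

Shuffling before the split matters here: each half-batch then contains a mix of fake and real sequences, so the two discriminator steps per batch see balanced targets rather than one all-fake and one all-real update.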