def generator_sample_example(self, positive_dir, negative_dir, num_batch):
    self.temp_positive_dir = positive_dir
    self.temp_negative_dir = negative_dir
    # Reset the output directories so each run starts from an empty state.
    if not os.path.exists(self.temp_positive_dir):
        os.mkdir(self.temp_positive_dir)
    if not os.path.exists(self.temp_negative_dir):
        os.mkdir(self.temp_negative_dir)
    shutil.rmtree(self.temp_negative_dir)
    shutil.rmtree(self.temp_positive_dir)
    if not os.path.exists(self.temp_positive_dir):
        os.mkdir(self.temp_positive_dir)
    if not os.path.exists(self.temp_negative_dir):
        os.mkdir(self.temp_negative_dir)

    counter = 0
    for _ in range(num_batch):
        decode_result = self._model.run_eval_given_step(self._sess, self.batches[self.current_batch])
        for i in range(FLAGS.batch_size):
            decoded_words_all = []
            original_review = self.batches[self.current_batch].original_review_output[i]
            for j in range(FLAGS.max_dec_sen_num):
                output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                decoded_words = data.outputids2words(output_ids, self._vocab, None)
                # Remove the [STOP] token from decoded_words, if necessary
                try:
                    fst_stop_idx = decoded_words.index(data.STOP_DECODING)  # index of the (first) [STOP] symbol
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    pass
                if len(decoded_words) < 2:
                    continue
                # Skip a sentence that mostly repeats the previous one.
                if len(decoded_words_all) > 0:
                    new_set1 = set(decoded_words_all[-1].split())
                    new_set2 = set(decoded_words)
                    if len(new_set1 & new_set2) > 0.5 * len(new_set2):
                        continue
                decoded_output = ' '.join(decoded_words).strip()  # single string
                decoded_words_all.append(decoded_output)
            decoded_words_all = ' '.join(decoded_words_all).strip()
            try:
                fst_stop_idx = decoded_words_all.index(data.STOP_DECODING_DOCUMENT)  # index of the (first) [STOP] symbol
                decoded_words_all = decoded_words_all[:fst_stop_idx]
            except ValueError:
                pass
            decoded_words_all = decoded_words_all.replace("[UNK] ", "")
            decoded_words_all = decoded_words_all.replace("[UNK]", "")
            decoded_words_all, _ = re.subn(r"(! ){2,}", "! ", decoded_words_all)
            decoded_words_all, _ = re.subn(r"(\. ){2,}", ". ", decoded_words_all)
            self.write_negtive_temp_to_json(original_review, decoded_words_all, counter)
            counter += 1  # this is how many examples we've decoded
        self.current_batch += 1
        if self.current_batch >= len(self.batches):
            self.current_batch = 0

    eva = Evaluate()
    eva.diversity_evaluate(negative_dir + "/*")
def generator_test_max_example(self, positive_dir, negative_dir, num_batch):
    self.temp_positive_dir = positive_dir
    self.temp_negative_dir = negative_dir
    if not os.path.exists(self.temp_positive_dir):
        os.mkdir(self.temp_positive_dir)
    if not os.path.exists(self.temp_negative_dir):
        os.mkdir(self.temp_negative_dir)
    shutil.rmtree(self.temp_negative_dir)
    shutil.rmtree(self.temp_positive_dir)
    if not os.path.exists(self.temp_positive_dir):
        os.mkdir(self.temp_positive_dir)
    if not os.path.exists(self.temp_negative_dir):
        os.mkdir(self.temp_negative_dir)

    counter = 0
    batches = self.test_batches
    step = 0
    list_hop = []
    list_ref = []
    while step < num_batch:
        batch = batches[step]
        step += 1
        decode_result = self._model.max_generator(self._sess, batch)
        for i in range(FLAGS.batch_size):
            decoded_words_all = []
            original_review = batch.original_review_output[i]
            for j in range(FLAGS.max_dec_sen_num):
                output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                decoded_words = data.outputids2words(output_ids, self._vocab, None)
                # Remove the [STOP] token from decoded_words, if necessary
                try:
                    fst_stop_idx = decoded_words.index(data.STOP_DECODING)  # index of the (first) [STOP] symbol
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    pass
                if len(decoded_words) < 2:
                    continue
                if len(decoded_words_all) > 0:
                    new_set1 = set(decoded_words_all[-1].split())
                    new_set2 = set(decoded_words)
                    if len(new_set1 & new_set2) > 0.5 * len(new_set2):
                        continue
                decoded_output = ' '.join(decoded_words).strip()  # single string
                decoded_words_all.append(decoded_output)
            decoded_words_all = ' '.join(decoded_words_all).strip()
            try:
                fst_stop_idx = decoded_words_all.index(data.STOP_DECODING_DOCUMENT)  # index of the (first) [STOP] symbol
                decoded_words_all = decoded_words_all[:fst_stop_idx]
            except ValueError:
                pass
            decoded_words_all = decoded_words_all.replace("[UNK] ", "")
            decoded_words_all = decoded_words_all.replace("[UNK]", "")
            decoded_words_all, _ = re.subn(r"(! ){2,}", "! ", decoded_words_all)
            decoded_words_all, _ = re.subn(r"(\. ){2,}", ". ", decoded_words_all)
            self.write_negtive_temp_to_json(batch.original_review_inputs[i], original_review, decoded_words_all, counter)
            list_ref.append([nltk.word_tokenize(original_review)])
            list_hop.append(nltk.word_tokenize(decoded_words_all))
            counter += 1  # this is how many examples we've decoded

    # bleu_score = corpus_bleu(list_ref, list_hop)
    # tf.logging.info('bleu: ' + str(bleu_score))
    eva = Evaluate()
    eva.diversity_evaluate(negative_dir + "/*")
def generator_max_example(self, target_batches, positive_dir, negetive_dir):
    self.temp_positive_dir = positive_dir
    self.temp_negetive_dir = negetive_dir
    if not os.path.exists(self.temp_positive_dir):
        os.mkdir(self.temp_positive_dir)
    if not os.path.exists(self.temp_negetive_dir):
        os.mkdir(self.temp_negetive_dir)
    shutil.rmtree(self.temp_negetive_dir)
    shutil.rmtree(self.temp_positive_dir)
    if not os.path.exists(self.temp_positive_dir):
        os.mkdir(self.temp_positive_dir)
    if not os.path.exists(self.temp_negetive_dir):
        os.mkdir(self.temp_negetive_dir)

    counter = 0
    batches = target_batches
    step = 0
    while step < len(target_batches):
        batch = copy.deepcopy(batches[step])
        step += 1
        decoded_words_all = [[] for _ in range(FLAGS.batch_size)]
        original_reviews = [" ".join(review) for review in batch.orginal_all_text]
        # tf.logging.info(batch.enc_lens)
        for k in range(FLAGS.max_dec_sen_num):
            # Generate one sentence per example with the main generator, expand it
            # with the SRL model, and feed the result back in as the next input.
            decode_result = self._model.max_generator(self._sess, batch)
            srl_batch, decode_mask = self.output_to_batch(batch, decode_result)
            decode_result_seq = self._srl_model.max_generator(self._sess_srl, srl_batch)
            if k < FLAGS.max_dec_sen_num - 1:
                batch = self.seq_output_to_batch(decode_result_seq, batch)

            for i in range(FLAGS.batch_size):
                if decode_mask[i] == 0:
                    # This example has already finished decoding.
                    decoded_words_all[i].append(data.STOP_DECODING_DOCUMENT)
                else:
                    output_ids = [int(t) for t in decode_result_seq['generated'][i]][0:]
                    decoded_words = data.outputids2words(output_ids, self._vocab, None)
                    # Remove the [STOP] token from decoded_words, if necessary
                    try:
                        fst_stop_idx = decoded_words.index(data.STOP_DECODING)  # index of the (first) [STOP] symbol
                        decoded_words = decoded_words[:fst_stop_idx]
                    except ValueError:
                        pass
                    if len(decoded_words) < 1:
                        continue
                    if len(decoded_words_all[i]) > 0:
                        new_set1 = set(decoded_words_all[i][-1].split())
                        new_set2 = set(decoded_words)
                        if len(new_set1 & new_set2) > 0.5 * len(new_set2):
                            continue
                    decoded_output = ' '.join(decoded_words).strip()  # single string
                    decoded_words_all[i].append(decoded_output)

        for i in range(FLAGS.batch_size):
            batch_seq = ' '.join(decoded_words_all[i]).strip()
            try:
                fst_stop_idx = batch_seq.index(data.STOP_DECODING_DOCUMENT)  # index of the (first) [STOP] symbol
                batch_seq = batch_seq[:fst_stop_idx]
            except ValueError:
                pass
            batch_seq = batch_seq.replace("[UNK] ", "")
            batch_seq = batch_seq.replace("[UNK]", "")
            batch_seq, _ = re.subn(r"(! ){2,}", "! ", batch_seq)
            batch_seq, _ = re.subn(r"(\. ){2,}", ". ", batch_seq)
            self.write_negtive_temp_to_json(positive_dir, negetive_dir, original_reviews[i], batch_seq)

    eva = Evaluate()
    eva.diversity_evaluate(negetive_dir + "/*")
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting running in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if FLAGS.mode == "train":
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
                   'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num',
                   'max_dec_steps', 'max_enc_steps']
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = ['lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
                   'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
                   'max_enc_sen_num', 'max_enc_seq_len']
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # Create a batcher object that will create minibatches of data
    batcher = GenBatcher(vocab, hps_generator)

    tf.set_random_seed(111)  # a seed value for randomness

    if hps_generator.mode == 'train':
        print("Start pre-training......")
        model = Generator(hps_generator, vocab)
        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
        generated = Generated_sample(model, vocab, batcher, sess_ge)

        print("Start pre-training generator......")
        run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge, train_dir_ge, generated)

        print("Generating negative examples......")
        generated.generator_whole_negetive_example()
        generated.generator_test_negetive_example()

        model_dis = Discriminator(hps_discriminator, vocab)
        dis_batcher = DisBatcher(hps_discriminator, vocab,
                                 "train/generated_samples_positive/*", "train/generated_samples_negetive/*",
                                 "test/generated_samples_positive/*", "test/generated_samples_negetive/*")
        sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis)

        print("Start pre-training discriminator......")
        # run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test")
        run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis, saver_dis, train_dir_dis)

        util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        generated.generator_sample_example("sample_temp_positive", "sample_temp_negetive", 1000)
        generated.generator_test_sample_example("test_sample_temp_positive", "test_sample_temp_negetive", 200)
        generated.generator_test_max_example("test_max_temp_positive", "test_max_temp_negetive", 200)

        tf.logging.info("true data diversity: ")
        eva = Evaluate()
        eva.diversity_evaluate("test_sample_temp_positive" + "/*")

        print("Start adversarial training......")
        whole_decay = False
        for epoch in range(1):
            batches = batcher.get_batches(mode='train')
            for step in range(int(len(batches) / 1000)):
                # run_train_generator(model, discriminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, generated)
                run_train_generator(model, model_dis, sess_dis, batcher, dis_batcher,
                                    batches[step * 1000:(step + 1) * 1000],
                                    sess_ge, saver_ge, train_dir_ge, generated)
                generated.generator_sample_example(
                    "sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                    "sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive", 1000)
                # generated.generator_max_example("max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                #                                 "max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive", 200)

                tf.logging.info("test performance: ")
                tf.logging.info("epoch: " + str(epoch) + " step: " + str(step))
                generated.generator_test_sample_example(
                    "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                    "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive", 200)
                generated.generator_test_max_example(
                    "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                    "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive", 200)

                # Rebuild the discriminator training queue from all samples generated so far.
                dis_batcher.train_queue = []
                for i in range(epoch + 1):
                    for j in range(step + 1):
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "sample_generated/" + str(i) + "epoch_step" + str(j) + "_temp_positive/*")
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "sample_generated/" + str(i) + "epoch_step" + str(j) + "_temp_negetive/*")
                dis_batcher.train_batch = dis_batcher.create_batches(mode="train", shuffleis=True)
                # dis_batcher.valid_batch = dis_batcher.train_batch
                whole_decay = run_train_discriminator(model_dis, 5, dis_batcher,
                                                      dis_batcher.get_batches(mode="train"),
                                                      sess_dis, saver_dis, train_dir_dis, whole_decay)

        '''elif hps_generator.mode == 'decode':
            decode_model_hps = hps_generator  # This will be the hyperparameters for the decoder model
            model = Generator(decode_model_hps, vocab)
            generated = Generated_sample(model, vocab, batcher)
            bleu_score = generated.compute_BLEU()
            tf.logging.info('bleu: %f', bleu_score)  # print the loss to screen'''

    else:
def generator_max_example(self, target_batches, positive_dir, negetive_dir):
    self.temp_positive_dir = positive_dir
    self.temp_negetive_dir = negetive_dir
    if not os.path.exists(self.temp_positive_dir):
        os.mkdir(self.temp_positive_dir)
    if not os.path.exists(self.temp_negetive_dir):
        os.mkdir(self.temp_negetive_dir)
    shutil.rmtree(self.temp_negetive_dir)
    shutil.rmtree(self.temp_positive_dir)
    if not os.path.exists(self.temp_positive_dir):
        os.mkdir(self.temp_positive_dir)
    if not os.path.exists(self.temp_negetive_dir):
        os.mkdir(self.temp_negetive_dir)

    counter = 0
    batches = target_batches
    step = 0
    while step < len(target_batches):
        batch = batches[step]
        step += 1
        decode_result = self._model.max_generator(self._sess, batch)
        for i in range(FLAGS.batch_size):
            original_review = batch.orig_outputs[i]
            output_ids = [int(t) for t in decode_result['generated'][i]][0:]
            decoded_words = data.outputids2words(output_ids, self._vocab, None)
            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass
            decoded_output = ' '.join(decoded_words).strip()  # single string
            try:
                fst_stop_idx = decoded_output.index(data.STOP_DECODING_DOCUMENT)  # index of the (first) [STOP] symbol
                decoded_output = decoded_output[:fst_stop_idx]
            except ValueError:
                pass
            decoded_output = decoded_output.replace("[UNK] ", "")
            decoded_output = decoded_output.replace("[UNK]", "")
            decoded_output, _ = re.subn(r"(! ){2,}", "! ", decoded_output)
            decoded_output, _ = re.subn(r"(\. ){2,}", ". ", decoded_output)
            self.write_negtive_temp_to_json(positive_dir, negetive_dir, original_review, decoded_output)

    eva = Evaluate()
    eva.diversity_evaluate(negetive_dir + "/*")