    def generator_sample_example(self, positive_dir, negative_dir, num_batch):
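        # Decode num_batch batches with the generator, assemble the generated
        # sentences into full reviews, dump them via write_negtive_temp_to_json,
        # and finally report the diversity of the generated (negative) samples.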
        self.temp_positive_dir = positive_dir
        self.temp_negative_dir = negative_dir

        if os.path.exists(self.temp_positive_dir): shutil.rmtree(self.temp_positive_dir)
        if os.path.exists(self.temp_negative_dir): shutil.rmtree(self.temp_negative_dir)
        os.mkdir(self.temp_positive_dir)
        os.mkdir(self.temp_negative_dir)
        counter = 0

        for _ in range(num_batch):
            decode_result = self._model.run_eval_given_step(self._sess, self.batches[self.current_batch])
            for i in range(FLAGS.batch_size):
                decoded_words_all = []
                original_review = self.batches[self.current_batch].original_review_output[i]
                for j in range(FLAGS.max_dec_sen_num):
                    output_ids = [int(t) for t in decode_result['generated'][i][j]][1:]
                    decoded_words = data.outputids2words(output_ids, self._vocab, None)
                    # Remove the [STOP] token from decoded_words, if necessary
                    try:
                        fst_stop_idx = decoded_words.index(data.STOP_DECODING)  # index of the (first) [STOP] symbol
                        decoded_words = decoded_words[:fst_stop_idx]
                    except ValueError:
                        pass  # no [STOP] token in this sentence
                    if len(decoded_words) < 2:
                        continue

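                    # Skip this sentence if more than half of its distinct words
                    # already appear in the previously kept sentence.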
                    if len(decoded_words_all) > 0:
                        new_set1 = set(decoded_words_all[-1].split())
                        new_set2 = set(decoded_words)
                        if len(new_set1 & new_set2) > 0.5 * len(new_set2):
                            continue
                    decoded_output = ' '.join(decoded_words).strip()  # single string
                    decoded_words_all.append(decoded_output)

                decoded_words_all = ' '.join(decoded_words_all).strip()
                try:
                    fst_stop_idx = decoded_words_all.index(
                        data.STOP_DECODING_DOCUMENT)  # index of the (first) [STOP] symbol
                    decoded_words_all = decoded_words_all[:fst_stop_idx]
                except ValueError:
                    pass  # no document-level [STOP] token
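                # Strip [UNK] tokens and collapse repeated "!" / "." before writing.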
                decoded_words_all = decoded_words_all.replace("[UNK] ", "")
                decoded_words_all = decoded_words_all.replace("[UNK]", "")
                decoded_words_all, _ = re.subn(r"(! ){2,}", "! ", decoded_words_all)
                decoded_words_all, _ = re.subn(r"(\. ){2,}", ". ", decoded_words_all)
                self.write_negtive_temp_to_json(original_review, decoded_words_all, counter)
                counter += 1  # this is how many examples we've decoded
            self.current_batch += 1
            if self.current_batch >= len(self.batches):
                self.current_batch = 0
        
        eva = Evaluate()
        eva.diversity_evaluate(negative_dir + "/*")
    def generator_test_max_example(self, positive_dir, negative_dir,
                                   num_batch):
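        # Greedy (max-probability) decoding on the test batches: write the
        # generated outputs together with the original reviews, then report
        # the diversity of the generated samples.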
        self.temp_positive_dir = positive_dir
        self.temp_negative_dir = negative_dir

        if os.path.exists(self.temp_positive_dir):
            shutil.rmtree(self.temp_positive_dir)
        if os.path.exists(self.temp_negative_dir):
            shutil.rmtree(self.temp_negative_dir)
        os.mkdir(self.temp_positive_dir)
        os.mkdir(self.temp_negative_dir)

        counter = 0
        batches = self.test_batches
        step = 0
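        # Tokenized references/hypotheses, collected for the (commented-out)
        # corpus BLEU computation at the end of this method.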
        list_hop = []
        list_ref = []
        while step < num_batch:
            batch = batches[step]
            step += 1
            decode_result = self._model.max_generator(self._sess, batch)

            for i in range(FLAGS.batch_size):

                decoded_words_all = []
                original_review = batch.original_review_output[i]

                for j in range(FLAGS.max_dec_sen_num):

                    output_ids = [
                        int(t) for t in decode_result['generated'][i][j]
                    ][1:]
                    decoded_words = data.outputids2words(
                        output_ids, self._vocab, None)
                    # Remove the [STOP] token from decoded_words, if necessary
                    try:
                        fst_stop_idx = decoded_words.index(
                            data.STOP_DECODING
                        )  # index of the (first) [STOP] symbol
                        decoded_words = decoded_words[:fst_stop_idx]
                    except ValueError:
                        pass  # no [STOP] token in this sentence

                    if len(decoded_words) < 2:
                        continue

                    if len(decoded_words_all) > 0:
                        new_set1 = set(decoded_words_all[-1].split())
                        new_set2 = set(decoded_words)
                        if len(new_set1 & new_set2) > 0.5 * len(new_set2):
                            continue
                    decoded_output = ' '.join(
                        decoded_words).strip()  # single string
                    decoded_words_all.append(decoded_output)
                decoded_words_all = ' '.join(decoded_words_all).strip()
                try:
                    fst_stop_idx = decoded_words_all.index(
                        data.STOP_DECODING_DOCUMENT
                    )  # index of the (first) [STOP] symbol
                    decoded_words_all = decoded_words_all[:fst_stop_idx]
                except ValueError:
                    pass  # no document-level [STOP] token
                decoded_words_all = decoded_words_all.replace("[UNK] ", "")
                decoded_words_all = decoded_words_all.replace("[UNK]", "")
                decoded_words_all, _ = re.subn(r"(! ){2,}", "! ",
                                               decoded_words_all)
                decoded_words_all, _ = re.subn(r"(\. ){2,}", ". ",
                                               decoded_words_all)
                self.write_negtive_temp_to_json(
                    batch.original_review_inputs[i], original_review,
                    decoded_words_all, counter)
                list_ref.append([nltk.word_tokenize(original_review)])
                list_hop.append(nltk.word_tokenize(decoded_words_all))

                counter += 1  # this is how many examples we've decoded

        # bleu_score = corpus_bleu(list_ref, list_hop)
        # tf.logging.info('bleu: '  + str(bleu_score))
        eva = Evaluate()
        eva.diversity_evaluate(negative_dir + "/*")
Example #4
    def generator_max_example(self, target_batches, positive_dir,
                              negetive_dir):
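        # Max decoding over target_batches, one sentence per step: the generator
        # output is converted into an SRL batch, the SRL model decodes the next
        # sentence, and the result is fed back into the batch for the following
        # step; the assembled texts are written to JSON and evaluated for diversity.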

        self.temp_positive_dir = positive_dir
        self.temp_negetive_dir = negetive_dir

        if os.path.exists(self.temp_positive_dir):
            shutil.rmtree(self.temp_positive_dir)
        if os.path.exists(self.temp_negetive_dir):
            shutil.rmtree(self.temp_negetive_dir)
        os.mkdir(self.temp_positive_dir)
        os.mkdir(self.temp_negetive_dir)
        counter = 0
        batches = target_batches
        step = 0

        while step < len(target_batches):

            batch = copy.deepcopy(batches[step])
            step += 1
            decoded_words_all = [[] for i in range(FLAGS.batch_size)]
            original_reviews = [
                " ".join(review) for review in batch.orginal_all_text
            ]
            #tf.logging.info(batch.enc_lens)

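            # Build each review one sentence at a time, up to max_dec_sen_num sentences.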
            for k in range(FLAGS.max_dec_sen_num):

                decode_result = self._model.max_generator(self._sess, batch)

                srl_batch, decode_mask = self.output_to_batch(
                    batch, decode_result)
                decode_result_seq = self._srl_model.max_generator(
                    self._sess_srl, srl_batch)
                if k < FLAGS.max_dec_sen_num - 1:

                    batch = self.seq_output_to_batch(decode_result_seq, batch)

                for i in range(FLAGS.batch_size):

                    if decode_mask[i] == 0:
                        decoded_words_all[i].append(
                            data.STOP_DECODING_DOCUMENT)
                    else:
                        output_ids = [
                            int(t) for t in decode_result_seq['generated'][i]
                        ][0:]
                        decoded_words = data.outputids2words(
                            output_ids, self._vocab, None)
                        # Remove the [STOP] token from decoded_words, if necessary
                        try:
                            fst_stop_idx = decoded_words.index(
                                data.STOP_DECODING
                            )  # index of the (first) [STOP] symbol
                            decoded_words = decoded_words[:fst_stop_idx]
                        except ValueError:
                            pass  # no [STOP] token in this sentence

                        if len(decoded_words) < 1:
                            continue

                        if len(decoded_words_all[i]) > 0:
                            new_set1 = set(decoded_words_all[i][-1].split())
                            new_set2 = set(decoded_words)
                            if len(new_set1 & new_set2) > 0.5 * len(new_set2):
                                continue
                        decoded_output = ' '.join(
                            decoded_words).strip()  # single string
                        decoded_words_all[i].append(decoded_output)

            for i in range(FLAGS.batch_size):
                batch_seq = ' '.join(decoded_words_all[i]).strip()
                try:
                    fst_stop_idx = batch_seq.index(
                        data.STOP_DECODING_DOCUMENT
                    )  # index of the (first) [STOP] symbol
                    batch_seq = batch_seq[:fst_stop_idx]
                except ValueError:
                    pass  # no document-level [STOP] token
                batch_seq = batch_seq.replace("[UNK] ", "")
                batch_seq = batch_seq.replace("[UNK]", "")
                batch_seq, _ = re.subn(r"(! ){2,}", "! ", batch_seq)
                batch_seq, _ = re.subn(r"(\. ){2,}", ". ", batch_seq)
                self.write_negtive_temp_to_json(positive_dir, negetive_dir,
                                                original_reviews[i], batch_seq)

        eva = Evaluate()
        eva.diversity_evaluate(negetive_dir + "/*")
Example #5
File: main.py  Project: bluepine/DPGAN
def main(unused_argv):
  if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
    raise Exception("Problem with flags: %s" % unused_argv)

  tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
  tf.logging.info('Starting running in %s mode...', (FLAGS.mode))

  # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
  FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
  if not os.path.exists(FLAGS.log_root):
    if FLAGS.mode=="train":
      os.makedirs(FLAGS.log_root)
    else:
      raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

  vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary


  # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
  hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num','max_dec_steps', 'max_enc_steps']
  hps_dict = {}
  for key,val in FLAGS.__flags.items(): # for each flag
    if key in hparam_list: # if it's in the list
      hps_dict[key] = val # add it to the dict
  hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

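  # Build a separate HParams tuple for the discriminator; it uses sentence-level
  # encoder limits (max_enc_sen_num, max_enc_seq_len) instead of decoder limits.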
  hparam_list = ['lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm',
                 'hidden_dim', 'emb_dim', 'batch_size', 'max_enc_sen_num', 'max_enc_seq_len']
  hps_dict = {}
  for key, val in FLAGS.__flags.items():  # for each flag
    if key in hparam_list:  # if it's in the list
      hps_dict[key] = val  # add it to the dict
  hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

  # Create a batcher object that will create minibatches of data
  batcher = GenBatcher(vocab, hps_generator)




  tf.set_random_seed(111) # a seed value for randomness





  if hps_generator.mode == 'train':
    print("Start pre-training......")
    model = Generator(hps_generator, vocab)

    sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
    generated = Generated_sample(model, vocab, batcher, sess_ge)
    print("Start pre-training generator......")
    run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge, train_dir_ge, generated)  # pre-train the generator before adversarial training

    print("Generating negetive examples......")
    generated.generator_whole_negetive_example()
    generated.generator_test_negetive_example()

    model_dis = Discriminator(hps_discriminator, vocab)
    dis_batcher = DisBatcher(hps_discriminator, vocab, "train/generated_samples_positive/*", "train/generated_samples_negetive/*", "test/generated_samples_positive/*", "test/generated_samples_negetive/*")
    sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis)
    print("Start pre-training discriminator......")
    #run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test")
    run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis, saver_dis, train_dir_dis)

    util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
    
    generated.generator_sample_example("sample_temp_positive", "sample_temp_negetive", 1000)

    generated.generator_test_sample_example("test_sample_temp_positive",
                                       "test_sample_temp_negetive",
                                       200)
    generated.generator_test_max_example("test_max_temp_positive",
                                       "test_max_temp_negetive",
                                       200)
    tf.logging.info("true data diversity: ")
    eva = Evaluate()
    eva.diversity_evaluate("test_sample_temp_positive" + "/*")



    print("Start adversial training......")
    whole_decay = False
    for epoch in range(1):
        batches = batcher.get_batches(mode='train')
        for step in range(int(len(batches)/1000)):

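            # One adversarial round: update the generator against the discriminator
            # on the next 1000 batches, regenerate samples, evaluate on the test
            # sets, then retrain the discriminator on the newly generated data.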
            run_train_generator(model, model_dis, sess_dis, batcher, dis_batcher, batches[step*1000:(step+1)*1000], sess_ge, saver_ge, train_dir_ge, generated)  # (model, discriminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, generated)
            generated.generator_sample_example("sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 1000)
            #generated.generator_max_example("max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 200)

            tf.logging.info("test performance: ")
            tf.logging.info("epoch: "+str(epoch)+" step: "+str(step))
            generated.generator_test_sample_example(
                "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive", 200)
            generated.generator_test_max_example("test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                                            "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive",
                                            200)

            dis_batcher.train_queue = []
            for i in range(epoch+1):
              for j in range(step+1):
                dis_batcher.train_queue += dis_batcher.fill_example_queue("sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_positive/*")
                dis_batcher.train_queue += dis_batcher.fill_example_queue("sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_negetive/*")
            dis_batcher.train_batch = dis_batcher.create_batches(mode="train", shuffleis=True)

            #dis_batcher.valid_batch = dis_batcher.train_batch
            whole_decay = run_train_discriminator(model_dis, 5, dis_batcher, dis_batcher.get_batches(mode="train"),
                                                  sess_dis, saver_dis, train_dir_dis, whole_decay)

  # elif hps_generator.mode == 'decode':
  #   decode_model_hps = hps_generator  # This will be the hyperparameters for the decoder model
  #   model = Generator(decode_model_hps, vocab)
  #   generated = Generated_sample(model, vocab, batcher)
  #   bleu_score = generated.compute_BLEU()
  #   tf.logging.info('bleu: %f', bleu_score)  # print the BLEU score to screen

  else:
Example #6
    def generator_max_example(self, target_batches, positive_dir,
                              negetive_dir):
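        # Max decoding in a single pass (no per-sentence loop): decode each batch
        # once, clean up the output, and write the (original, generated) pairs to
        # JSON before running the diversity evaluation.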

        self.temp_positive_dir = positive_dir
        self.temp_negetive_dir = negetive_dir

        if os.path.exists(self.temp_positive_dir):
            shutil.rmtree(self.temp_positive_dir)
        if os.path.exists(self.temp_negetive_dir):
            shutil.rmtree(self.temp_negetive_dir)
        os.mkdir(self.temp_positive_dir)
        os.mkdir(self.temp_negetive_dir)
        counter = 0
        batches = target_batches
        step = 0

        while step < len(target_batches):

            batch = batches[step]
            step += 1

            decode_result = self._model.max_generator(self._sess, batch)

            for i in range(FLAGS.batch_size):

                original_review = batch.orig_outputs[i]

                output_ids = [int(t)
                              for t in decode_result['generated'][i]][0:]
                decoded_words = data.outputids2words(output_ids, self._vocab,
                                                     None)
                # Remove the [STOP] token from decoded_words, if necessary
                try:
                    fst_stop_idx = decoded_words.index(
                        data.STOP_DECODING
                    )  # index of the (first) [STOP] symbol
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    pass  # no [STOP] token in this output

                decoded_output = ' '.join(
                    decoded_words).strip()  # single string

                try:
                    fst_stop_idx = decoded_output.index(
                        data.STOP_DECODING_DOCUMENT
                    )  # index of the (first) [STOP] symbol
                    decoded_output = decoded_output[:fst_stop_idx]
                except ValueError:
                    pass  # no document-level [STOP] token
                # Strip [UNK] tokens and collapse repeated "!" / "." before writing.
                decoded_output = decoded_output.replace("[UNK] ", "")
                decoded_output = decoded_output.replace("[UNK]", "")
                decoded_output, _ = re.subn(r"(! ){2,}", "! ", decoded_output)
                decoded_output, _ = re.subn(r"(\. ){2,}", ". ", decoded_output)
                self.write_negtive_temp_to_json(positive_dir, negetive_dir,
                                                original_review,
                                                decoded_output)

        eva = Evaluate()
        eva.diversity_evaluate(negetive_dir + "/*")